From 10e3514c962b58cbbee994257872a626ff76d51b Mon Sep 17 00:00:00 2001 From: cyy Date: Sat, 9 Aug 2025 02:21:22 +0000 Subject: [PATCH] Remove tensorexpr tests (#158928) The tests are not maintained. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158928 Approved by: https://github.com/albanD, https://github.com/malfet --- .ci/pytorch/build.sh | 4 - .ci/pytorch/test.sh | 10 - aten/src/ATen/test/thread_init_test.cpp | 11 +- caffe2/CMakeLists.txt | 4 - test/cpp/tensorexpr/CMakeLists.txt | 83 - test/cpp/tensorexpr/README.md | 55 - test/cpp/tensorexpr/gtest_assert_float_eq.h | 119 - test/cpp/tensorexpr/padded_buffer.cpp | 37 - test/cpp/tensorexpr/padded_buffer.h | 242 - test/cpp/tensorexpr/test_approx.cpp | 96 - test/cpp/tensorexpr/test_aten.cpp | 1068 --- test/cpp/tensorexpr/test_base.h | 89 - test/cpp/tensorexpr/test_boundsinference.cpp | 1019 --- test/cpp/tensorexpr/test_conv.cpp | 234 - test/cpp/tensorexpr/test_cpp_codegen.cpp | 259 - test/cpp/tensorexpr/test_cuda.cpp | 2344 ------ test/cpp/tensorexpr/test_dynamic_shapes.cpp | 701 -- test/cpp/tensorexpr/test_expr.cpp | 836 -- test/cpp/tensorexpr/test_external_calls.cpp | 1061 --- test/cpp/tensorexpr/test_graph_opt.cpp | 319 - test/cpp/tensorexpr/test_ir_printer.cpp | 98 - test/cpp/tensorexpr/test_ir_verifier.cpp | 191 - test/cpp/tensorexpr/test_kernel.cpp | 2133 ----- test/cpp/tensorexpr/test_llvm.cpp | 1799 ----- test/cpp/tensorexpr/test_loopnest.cpp | 6894 ----------------- test/cpp/tensorexpr/test_memdependency.cpp | 3252 -------- test/cpp/tensorexpr/test_memplanning.cpp | 708 -- test/cpp/tensorexpr/test_ops.cpp | 78 - test/cpp/tensorexpr/test_quantization.cpp | 452 -- test/cpp/tensorexpr/test_reductions.cpp | 1928 ----- test/cpp/tensorexpr/test_registerizer.cpp | 3702 --------- test/cpp/tensorexpr/test_simplify.cpp | 5680 -------------- test/cpp/tensorexpr/test_te_fuser_pass.cpp | 402 - test/cpp/tensorexpr/test_type.cpp | 202 - .../tensorexpr/test_type_specializations.cpp | 75 - test/cpp/tensorexpr/test_utils.h | 78 - test/cpp/tensorexpr/tutorial.cpp | 542 -- test/test_jit_fuser_te.py | 5 +- torch/csrc/jit/runtime/static/ops.cpp | 2 +- 39 files changed, 10 insertions(+), 36802 deletions(-) delete mode 100644 test/cpp/tensorexpr/CMakeLists.txt delete mode 100644 test/cpp/tensorexpr/README.md delete mode 100644 test/cpp/tensorexpr/gtest_assert_float_eq.h delete mode 100644 test/cpp/tensorexpr/padded_buffer.cpp delete mode 100644 test/cpp/tensorexpr/padded_buffer.h delete mode 100644 test/cpp/tensorexpr/test_approx.cpp delete mode 100644 test/cpp/tensorexpr/test_aten.cpp delete mode 100644 test/cpp/tensorexpr/test_base.h delete mode 100644 test/cpp/tensorexpr/test_boundsinference.cpp delete mode 100644 test/cpp/tensorexpr/test_conv.cpp delete mode 100644 test/cpp/tensorexpr/test_cpp_codegen.cpp delete mode 100644 test/cpp/tensorexpr/test_cuda.cpp delete mode 100644 test/cpp/tensorexpr/test_dynamic_shapes.cpp delete mode 100644 test/cpp/tensorexpr/test_expr.cpp delete mode 100644 test/cpp/tensorexpr/test_external_calls.cpp delete mode 100644 test/cpp/tensorexpr/test_graph_opt.cpp delete mode 100644 test/cpp/tensorexpr/test_ir_printer.cpp delete mode 100644 test/cpp/tensorexpr/test_ir_verifier.cpp delete mode 100644 test/cpp/tensorexpr/test_kernel.cpp delete mode 100644 test/cpp/tensorexpr/test_llvm.cpp delete mode 100644 test/cpp/tensorexpr/test_loopnest.cpp delete mode 100644 test/cpp/tensorexpr/test_memdependency.cpp delete mode 100644 test/cpp/tensorexpr/test_memplanning.cpp delete mode 100644 
test/cpp/tensorexpr/test_ops.cpp delete mode 100644 test/cpp/tensorexpr/test_quantization.cpp delete mode 100644 test/cpp/tensorexpr/test_reductions.cpp delete mode 100644 test/cpp/tensorexpr/test_registerizer.cpp delete mode 100644 test/cpp/tensorexpr/test_simplify.cpp delete mode 100644 test/cpp/tensorexpr/test_te_fuser_pass.cpp delete mode 100644 test/cpp/tensorexpr/test_type.cpp delete mode 100644 test/cpp/tensorexpr/test_type_specializations.cpp delete mode 100644 test/cpp/tensorexpr/test_utils.h delete mode 100644 test/cpp/tensorexpr/tutorial.cpp diff --git a/.ci/pytorch/build.sh b/.ci/pytorch/build.sh index c7d2cb93a64b..65f97389324a 100755 --- a/.ci/pytorch/build.sh +++ b/.ci/pytorch/build.sh @@ -50,9 +50,6 @@ if [[ ${BUILD_ENVIRONMENT} == *"parallelnative"* ]]; then export ATEN_THREADING=NATIVE fi -# Enable LLVM dependency for TensorExpr testing -export USE_LLVM=/opt/llvm -export LLVM_DIR=/opt/llvm/lib/cmake/llvm if ! which conda; then # In ROCm CIs, we are doing cross compilation on build machines with @@ -192,7 +189,6 @@ if [[ "$BUILD_ENVIRONMENT" == *-clang*-asan* ]]; then export USE_ASAN=1 export REL_WITH_DEB_INFO=1 export UBSAN_FLAGS="-fno-sanitize-recover=all" - unset USE_LLVM fi if [[ "${BUILD_ENVIRONMENT}" == *no-ops* ]]; then diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index 84d40a2e458a..473a125475c4 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -1051,20 +1051,10 @@ test_libtorch_api() { mkdir -p $TEST_REPORTS_DIR OMP_NUM_THREADS=2 TORCH_CPP_TEST_MNIST_PATH="${MNIST_DIR}" "$TORCH_BIN_DIR"/test_api --gtest_filter='-IMethodTest.*' --gtest_output=xml:$TEST_REPORTS_DIR/test_api.xml - "$TORCH_BIN_DIR"/test_tensorexpr --gtest_output=xml:$TEST_REPORTS_DIR/test_tensorexpr.xml else # Exclude IMethodTest that relies on torch::deploy, which will instead be ran in test_deploy OMP_NUM_THREADS=2 TORCH_CPP_TEST_MNIST_PATH="${MNIST_DIR}" python test/run_test.py --cpp --verbose -i cpp/test_api -k "not IMethodTest" - # On s390x, pytorch is built without llvm. - # Even if it would be built with llvm, llvm currently doesn't support used features on s390x and - # test fails with errors like: - # JIT session error: Unsupported target machine architecture in ELF object pytorch-jitted-objectbuffer - # unknown file: Failure - # C++ exception with description "valOrErr INTERNAL ASSERT FAILED at "/var/lib/jenkins/workspace/torch/csrc/jit/tensorexpr/llvm_jit.h":34, please report a bug to PyTorch. Unexpected failure in LLVM JIT: Failed to materialize symbols: { (main, { func }) } - if [[ "${BUILD_ENVIRONMENT}" != *s390x* ]]; then - python test/run_test.py --cpp --verbose -i cpp/test_tensorexpr - fi fi # quantization is not fully supported on s390x yet diff --git a/aten/src/ATen/test/thread_init_test.cpp b/aten/src/ATen/test/thread_init_test.cpp index 7ad7a18e9c66..60dd52d1dffc 100644 --- a/aten/src/ATen/test/thread_init_test.cpp +++ b/aten/src/ATen/test/thread_init_test.cpp @@ -1,7 +1,8 @@ +#include + #include #include #include -#include #include @@ -9,7 +10,7 @@ // numbers of threads set and also whether the scheduler // will throw an exception when multiple threads call // their first parallel construct. 
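
The hunk below folds the standalone `thread_init_test` binary into the shared gtest runner. For reference, a minimal sketch of the same `main()`-to-`TEST` conversion pattern, using a hypothetical `Checks` suite rather than the actual ATen API:

```cpp
#include <gtest/gtest.h>

// Before: a freestanding binary with its own entry point, e.g.
//   void check(int n) { if (n < 0) std::abort(); }
//   int main() { check(4); return 0; }

// After: a gtest case discovered by the shared runner. The helper
// becomes static, and gtest assertions replace manual aborts so a
// failure is reported instead of killing the process.
static void check(int n) {
  ASSERT_GE(n, 0);
}

TEST(Checks, Basic) {
  check(4);
}
```
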
-void test(int given_num_threads) { +static void test(int given_num_threads) { auto t = at::ones({1000 * 1000}, at::CPU(at::kFloat)); ASSERT_TRUE(given_num_threads >= 0); ASSERT_EQ(at::get_num_threads(), given_num_threads); @@ -19,7 +20,7 @@ void test(int given_num_threads) { } } -int main() { +TEST(ThreadInitTest, ThreadInit) { at::init_num_threads(); at::set_num_threads(4); @@ -32,13 +33,11 @@ int main() { #if !AT_PARALLEL_NATIVE at::set_num_threads(5); - ASSERT_TRUE(at::get_num_threads() == 5); + ASSERT_EQ(at::get_num_threads(), 5); #endif // test inter-op settings at::set_num_interop_threads(5); ASSERT_EQ(at::get_num_interop_threads(), 5); ASSERT_ANY_THROW(at::set_num_interop_threads(6)); - - return 0; } diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index c346cedbcf51..96ed0c3b918e 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -1345,10 +1345,6 @@ if(BUILD_TEST) add_subdirectory(${TORCH_ROOT}/test/cpp/jit ${CMAKE_BINARY_DIR}/test_jit) add_subdirectory(${TORCH_ROOT}/test/cpp/nativert ${CMAKE_BINARY_DIR}/test_nativert) add_subdirectory(${TORCH_ROOT}/test/inductor ${CMAKE_BINARY_DIR}/test_inductor) - add_subdirectory( - ${TORCH_ROOT}/test/cpp/tensorexpr - ${CMAKE_BINARY_DIR}/test_tensorexpr - ) if(USE_DISTRIBUTED) add_subdirectory(${TORCH_ROOT}/test/cpp/c10d ${CMAKE_BINARY_DIR}/test_cpp_c10d) if(NOT WIN32) diff --git a/test/cpp/tensorexpr/CMakeLists.txt b/test/cpp/tensorexpr/CMakeLists.txt deleted file mode 100644 index 8fe6ffd525e9..000000000000 --- a/test/cpp/tensorexpr/CMakeLists.txt +++ /dev/null @@ -1,83 +0,0 @@ -set(TENSOREXPR_TEST_ROOT ${TORCH_ROOT}/test/cpp/tensorexpr) - -set(TENSOREXPR_TEST_SRCS - ${TENSOREXPR_TEST_ROOT}/test_approx.cpp - ${TENSOREXPR_TEST_ROOT}/test_aten.cpp - ${TENSOREXPR_TEST_ROOT}/test_boundsinference.cpp - ${TENSOREXPR_TEST_ROOT}/test_conv.cpp - ${TENSOREXPR_TEST_ROOT}/test_cpp_codegen.cpp - ${TENSOREXPR_TEST_ROOT}/test_dynamic_shapes.cpp - ${TENSOREXPR_TEST_ROOT}/test_expr.cpp - ${TENSOREXPR_TEST_ROOT}/test_external_calls.cpp - ${TENSOREXPR_TEST_ROOT}/test_graph_opt.cpp - ${TENSOREXPR_TEST_ROOT}/test_ir_printer.cpp - ${TENSOREXPR_TEST_ROOT}/test_ir_verifier.cpp - ${TENSOREXPR_TEST_ROOT}/test_kernel.cpp - ${TENSOREXPR_TEST_ROOT}/test_loopnest.cpp - ${TENSOREXPR_TEST_ROOT}/test_memdependency.cpp - ${TENSOREXPR_TEST_ROOT}/test_ops.cpp - ${TENSOREXPR_TEST_ROOT}/test_quantization.cpp - ${TENSOREXPR_TEST_ROOT}/test_memplanning.cpp - ${TENSOREXPR_TEST_ROOT}/test_reductions.cpp - ${TENSOREXPR_TEST_ROOT}/test_registerizer.cpp - ${TENSOREXPR_TEST_ROOT}/test_simplify.cpp - ${TENSOREXPR_TEST_ROOT}/test_te_fuser_pass.cpp - ${TENSOREXPR_TEST_ROOT}/test_type.cpp - ${TENSOREXPR_TEST_ROOT}/test_type_specializations.cpp -) - -if(USE_CUDA) - list(APPEND TENSOREXPR_TEST_SRCS ${TENSOREXPR_TEST_ROOT}/test_cuda.cpp) -endif() - -if(USE_LLVM AND LLVM_FOUND) - list(APPEND TENSOREXPR_TEST_SRCS ${TENSOREXPR_TEST_ROOT}/test_llvm.cpp) -endif() - -add_executable(test_tensorexpr - ${TORCH_ROOT}/test/cpp/common/main.cpp - ${TENSOREXPR_TEST_ROOT}/padded_buffer.cpp - ${TENSOREXPR_TEST_SRCS}) - -target_link_libraries(test_tensorexpr PRIVATE torch gtest_main) -target_include_directories(test_tensorexpr PRIVATE ${ATen_CPU_INCLUDE}) -target_compile_definitions(test_tensorexpr PRIVATE USE_GTEST) - -add_executable(tutorial_tensorexpr ${TENSOREXPR_TEST_ROOT}/tutorial.cpp) -target_link_libraries(tutorial_tensorexpr PRIVATE torch) -target_include_directories(tutorial_tensorexpr PRIVATE ${ATen_CPU_INCLUDE}) - -# The test case depends on the xnnpack header which in turn 
depends on the -# pthreadpool header. For some build environment we need add the dependency -# explicitly. -if(USE_PTHREADPOOL) - target_link_libraries(test_tensorexpr PRIVATE pthreadpool_interface) -endif() -if(USE_CUDA) - target_compile_definitions(test_tensorexpr PRIVATE USE_CUDA) - target_compile_definitions(tutorial_tensorexpr PRIVATE USE_CUDA) -elseif(USE_ROCM) - target_link_libraries(test_tensorexpr PRIVATE - hiprtc::hiprtc - hip::amdhip64 - ${TORCH_CUDA_LIBRARIES}) - target_compile_definitions(test_tensorexpr PRIVATE USE_ROCM) - - target_link_libraries(tutorial_tensorexpr PRIVATE - hiprtc::hiprtc - hip::amdhip64 - ${TORCH_CUDA_LIBRARIES}) - target_compile_definitions(tutorial_tensorexpr PRIVATE USE_ROCM) -endif() - -if(INSTALL_TEST) - set_target_properties(test_tensorexpr PROPERTIES INSTALL_RPATH "${CMAKE_INSTALL_RPATH}:${_rpath_portable_origin}/../lib") - install(TARGETS test_tensorexpr DESTINATION bin) - set_target_properties(tutorial_tensorexpr PROPERTIES INSTALL_RPATH "${CMAKE_INSTALL_RPATH}:${_rpath_portable_origin}/../lib") - install(TARGETS tutorial_tensorexpr DESTINATION bin) - # Install PDB files for MSVC builds - if(MSVC AND BUILD_SHARED_LIBS) - install(FILES $ DESTINATION bin OPTIONAL) - install(FILES $ DESTINATION bin OPTIONAL) - endif() -endif() diff --git a/test/cpp/tensorexpr/README.md b/test/cpp/tensorexpr/README.md deleted file mode 100644 index f86a50a65e80..000000000000 --- a/test/cpp/tensorexpr/README.md +++ /dev/null @@ -1,55 +0,0 @@ -# TensorExpr C++ Tests - -## How to add a new test -First, create a new test file. Test files should have be placed in this -directory, with a name that starts with `test_`, like `test_foo.cpp`. - -Here is an example test file you can copy-paste. -```cpp -#include - -// Tests go in torch::jit -namespace torch { -namespace jit { - -// 1. Test cases are void() functions. -// 2. They start with the prefix `test` -void testCaseOne() { - // ... -} - -void testCaseTwo() { - // ... -} -} -} -``` - -Then, register your test in `tests.h`: -```cpp -// Add to TH_FORALL_TESTS_CUDA instead for CUDA-requiring tests -#define TH_FORALL_TESTS(_) \ - _(ADFormulas) \ - _(Attributes) \ - ... - _(CaseOne) // note that the `test` prefix is omitted. - _(CaseTwo) -``` - -We glob all the test files together in `CMakeLists.txt` so that you don't -have to edit it every time you add a test. Unfortunately, this means that in -order to get the build to pick up your new test file, you need to re-run -cmake: -```bash -CMAKE_FRESH=1 python setup.py build -``` - -## How do I run the tests? -The following commands assume you are in PyTorch root. - - ```bash - # (re)build the test binary - ninja build/bin/test_tensorexpr - # run - build/bin/test_tensorexpr --gtest_filter='glob_style_filter*' - ``` diff --git a/test/cpp/tensorexpr/gtest_assert_float_eq.h b/test/cpp/tensorexpr/gtest_assert_float_eq.h deleted file mode 100644 index f85264a8f5d3..000000000000 --- a/test/cpp/tensorexpr/gtest_assert_float_eq.h +++ /dev/null @@ -1,119 +0,0 @@ -#pragma once - -#include -// Copyright 2005, Google Inc. -// All rights reserved. -// -// Redistribution and use in source and binary forms, with or without -// modification, are permitted provided that the following conditions are -// met: -// -// * Redistributions of source code must retain the above copyright -// notice, this list of conditions and the following disclaimer. 
-// * Redistributions in binary form must reproduce the above -// copyright notice, this list of conditions and the following disclaimer -// in the documentation and/or other materials provided with the -// distribution. -// * Neither the name of Google Inc. nor the names of its -// contributors may be used to endorse or promote products derived from -// this software without specific prior written permission. -// -// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -// -// The Google C++ Testing and Mocking Framework (Google Test) -// -// This header file declares functions and macros used internally by -// Google Test. They are subject to change without notice. - -using Bits = uint32_t; - -// this avoids the "dereferencing type-punned pointer -// will break strict-aliasing rules" error -union Float { - float float_; - Bits bits_; -}; - -// # of bits in a number. -static const size_t kBitCount = 8 * sizeof(Bits); -// The mask for the sign bit. -static const Bits kSignBitMask = static_cast(1) << (kBitCount - 1); - -// GOOGLETEST_CM0001 DO NOT DELETE - -// Converts an integer from the sign-and-magnitude representation to -// the biased representation. More precisely, let N be 2 to the -// power of (kBitCount - 1), an integer x is represented by the -// unsigned number x + N. -// -// For instance, -// -// -N + 1 (the most negative number representable using -// sign-and-magnitude) is represented by 1; -// 0 is represented by N; and -// N - 1 (the biggest number representable using -// sign-and-magnitude) is represented by 2N - 1. -// -// Read http://en.wikipedia.org/wiki/Signed_number_representations -// for more details on signed number representations. -static Bits SignAndMagnitudeToBiased(const Bits& sam) { - if (kSignBitMask & sam) { - // sam represents a negative number. - return ~sam + 1; - } else { - // sam represents a positive number. - return kSignBitMask | sam; - } -} - -// Given two numbers in the sign-and-magnitude representation, -// returns the distance between them as an unsigned number. -static Bits DistanceBetweenSignAndMagnitudeNumbers( - const Bits& sam1, - const Bits& sam2) { - const Bits biased1 = SignAndMagnitudeToBiased(sam1); - const Bits biased2 = SignAndMagnitudeToBiased(sam2); - return (biased1 >= biased2) ? (biased1 - biased2) : (biased2 - biased1); -} - -// How many ULP's (Units in the Last Place) we want to tolerate when -// comparing two numbers. The larger the value, the more error we -// allow. A 0 value means that two numbers must be exactly the same -// to be considered equal. -// -// The maximum error of a single floating-point operation is 0.5 -// units in the last place. On Intel CPU's, all floating-point -// calculations are done with 80-bit precision, while double has 64 -// bits. 
Therefore, 4 should be enough for ordinary use. -// -// See the following article for more details on ULP: -// http://randomascii.wordpress.com/2012/02/25/comparing-floating-point-numbers-2012-edition/ -static const size_t kMaxUlps = 4; - -// Returns true if and only if this number is at most kMaxUlps ULP's away -// from rhs. In particular, this function: -// -// - returns false if either number is (or both are) NAN. -// - treats really large numbers as almost equal to infinity. -// - thinks +0.0 and -0.0 are 0 DLP's apart. -inline bool AlmostEquals(float lhs, float rhs) { - // The IEEE standard says that any comparison operation involving - // a NAN must return false. - if (std::isnan(lhs) || std::isnan(rhs)) - return false; - - Float l = {lhs}; - Float r = {rhs}; - - return DistanceBetweenSignAndMagnitudeNumbers(l.bits_, r.bits_) <= kMaxUlps; -} diff --git a/test/cpp/tensorexpr/padded_buffer.cpp b/test/cpp/tensorexpr/padded_buffer.cpp deleted file mode 100644 index 424d82c77453..000000000000 --- a/test/cpp/tensorexpr/padded_buffer.cpp +++ /dev/null @@ -1,37 +0,0 @@ -#include "test/cpp/tensorexpr/padded_buffer.h" - -#include -#include -#include - -namespace torch { -namespace jit { -namespace tensorexpr { - -int PaddedBufferBase::Index(const std::vector& indices) const { - TORCH_DCHECK_EQ(dims_.size(), indices.size()); - int total_index = 0; - for (const auto i : c10::irange(dims_.size())) { - total_index += indices[i] * strides_[i]; - } - return total_index; -} - -PaddedBufferBase::PaddedBufferBase( - const std::vector& dims, - // NOLINTNEXTLINE(modernize-pass-by-value) - const std::string& name) - : dims_(dims), name_(name), strides_(dims.size()) { - for (int i = (int)dims.size() - 1; i >= 0; --i) { - if (i == (int)dims.size() - 1) { - strides_[i] = 1; - } else { - strides_[i] = strides_[i + 1] * dims[i + 1]; - } - } - total_size_ = strides_[0] * dims[0]; -} - -} // namespace tensorexpr -} // namespace jit -} // namespace torch diff --git a/test/cpp/tensorexpr/padded_buffer.h b/test/cpp/tensorexpr/padded_buffer.h deleted file mode 100644 index b3e5227ae7e6..000000000000 --- a/test/cpp/tensorexpr/padded_buffer.h +++ /dev/null @@ -1,242 +0,0 @@ -#pragma once - -#include -#include - -#include -#include "torch/csrc/jit/tensorexpr/eval.h" - -namespace torch { -namespace jit { -namespace tensorexpr { - -template -struct DefaultPaddedValue; - -template <> -struct DefaultPaddedValue { - static const int kValue = static_cast(0xDEADBEEF); -}; - -template <> -struct DefaultPaddedValue { - static const int8_t kValue = static_cast(0xBE); -}; - -template <> -struct DefaultPaddedValue { - static const uint8_t kValue = static_cast(0xBE); -}; - -template <> -struct DefaultPaddedValue { - static const int16_t kValue = static_cast(0xBEEF); -}; - -template <> -struct DefaultPaddedValue { - static const int64_t kValue = static_cast(0xDEADBEEF); -}; - -template <> -struct DefaultPaddedValue { - static constexpr float kValue = 0.1357; -}; - -template <> -struct DefaultPaddedValue { - // at::Half ctor isn't constexpr, so just fill it with bits. - static constexpr uint16_t kValue = 1357; -}; - -template <> -struct DefaultPaddedValue { - static constexpr double kValue = 0.1357; -}; - -// A concrete base to be used in PaddedBase. 
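
A quick aside on `AlmostEquals` from `gtest_assert_float_eq.h` above, before the padded-buffer classes below: the same ULP comparison can be reproduced standalone. A self-contained sketch (helper names hypothetical; `memcpy` replaces the type-punning union):

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>

// Map a float's bits onto a biased integer scale so that adjacent
// representable floats become adjacent integers, then compare distance.
static uint32_t biased(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  const uint32_t kSignBit = 1u << 31;
  return (bits & kSignBit) ? ~bits + 1 : kSignBit | bits;
}

static bool almostEquals(float lhs, float rhs, uint32_t max_ulps = 4) {
  if (std::isnan(lhs) || std::isnan(rhs)) {
    return false; // NaN never compares equal, matching IEEE semantics
  }
  uint32_t a = biased(lhs);
  uint32_t b = biased(rhs);
  return (a >= b ? a - b : b - a) <= max_ulps;
}

int main() {
  assert(almostEquals(1.0f, std::nextafterf(1.0f, 2.0f))); // 1 ULP apart
  assert(!almostEquals(1.0f, 1.1f)); // ~840k ULPs apart
  return 0;
}
```
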
-class PaddedBufferBase {
- public:
-  const std::string& name() const {
-    return name_;
-  }
-
-  int size() const {
-    return total_size_;
-  }
-
-  int raw_size() const {
-    return total_size_ + 2 * kPaddingSize;
-  }
-
-  virtual ~PaddedBufferBase() {}
-
- protected:
-  explicit PaddedBufferBase(
-      const std::vector<int>& dims,
-      const std::string& name);
-  int Index(const std::vector<int>& indices) const;
-
-  std::vector<int> dims_;
-  std::string name_;
-  std::vector<int> strides_;
-  int total_size_; // total number of useful elements; does not include the
-                   // paddings
-  static constexpr int kPaddingSize = 64;
-};
-
-// A padded buffer with watermarks for testing.
-// The buffer carries padded watermarks on both sides to catch potential
-// out-of-bounds writes. For read-only data that are not supposed to change, it
-// can also make a backup and be compared later.
-template <typename T>
-class PaddedBuffer : public PaddedBufferBase {
- public:
-  PaddedBuffer(int d0, const std::string& name = "")
-      : PaddedBuffer(std::vector<int>({d0}), name) {}
-  PaddedBuffer(int d0, int d1, const std::string& name = "")
-      : PaddedBuffer(std::vector<int>({d0, d1}), name) {}
-  PaddedBuffer(int d0, int d1, int d2, const std::string& name = "")
-      : PaddedBuffer(std::vector<int>({d0, d1, d2}), name) {}
-  PaddedBuffer(int d0, int d1, int d2, int d3, const std::string& name = "")
-      : PaddedBuffer(std::vector<int>({d0, d1, d2, d3}), name) {}
-  PaddedBuffer(const std::vector<int>& dims, const std::string& name = "")
-      : PaddedBufferBase(dims, name) {
-    data_.resize(total_size_ + 2 * kPaddingSize, kPaddingValue);
-  }
-  PaddedBuffer(const PaddedBuffer& other, const std::string& name)
-      : PaddedBuffer(other) {
-    this->name_ = name;
-  }
-
-  T* data() {
-    return data_.data() + kPaddingSize;
-  }
-  const T* data() const {
-    return const_cast<PaddedBuffer*>(this)->data();
-  }
-  T* raw_data() {
-    return data_.data();
-  }
-  const T* raw_data() const {
-    return const_cast<PaddedBuffer*>(this)->raw_data();
-  }
-  T& operator()(int i0) {
-    // There is a bit of a performance impact from forming a vector here. But
-    // this data structure is for testing only, and not performance critical.
-    return this->operator()(std::vector<int>({i0}));
-  }
-  const T& operator()(int i0) const {
-    return const_cast<PaddedBuffer*>(this)->operator()(i0);
-  }
-  T& operator()(int i0, int i1) {
-    return this->operator()(std::vector<int>({i0, i1}));
-  }
-  const T& operator()(int i0, int i1) const {
-    return const_cast<PaddedBuffer*>(this)->operator()(i0, i1);
-  }
-  T& operator()(int i0, int i1, int i2) {
-    return this->operator()(std::vector<int>({i0, i1, i2}));
-  }
-  const T& operator()(int i0, int i1, int i2) const {
-    return const_cast<PaddedBuffer*>(this)->operator()(i0, i1, i2);
-  }
-  T& operator()(int i0, int i1, int i2, int i3) {
-    return this->operator()(std::vector<int>({i0, i1, i2, i3}));
-  }
-  const T& operator()(int i0, int i1, int i2, int i3) const {
-    return const_cast<PaddedBuffer*>(this)->operator()(i0, i1, i2, i3);
-  }
-  T& operator()(const std::vector<int>& indices) {
-    return data_[kPaddingSize + Index(indices)];
-  }
-  const T& operator()(const std::vector<int>& indices) const {
-    return const_cast<PaddedBuffer*>(this)->operator()(indices);
-  }
-
-  template <typename U>
-  friend void ExpectAllNear(
-      const PaddedBuffer<U>& v1,
-      const PaddedBuffer<U>& v2,
-      float abs_error);
-  template <typename U>
-  friend void ExpectAllEqual(
-      const PaddedBuffer<U>& v1,
-      const PaddedBuffer<U>& v2);
-  void Backup() {
-    backup_data_ = data_;
-  }
-
-  // Verify the watermarks in the paddings are intact.
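
The watermark scheme that `ValidateWatermark` below checks is a classic guard-band test. A standalone sketch of the idea under hypothetical names (sentinel value and pad width borrowed from the header above):

```cpp
#include <cassert>
#include <vector>

int main() {
  constexpr int kPad = 64;
  constexpr int kSize = 100;
  const float kSentinel = 0.1357f;

  // Sentinel-filled guard bands surround the useful region.
  std::vector<float> buf(kSize + 2 * kPad, kSentinel);
  float* data = buf.data() + kPad; // pointer handed to the code under test

  for (int i = 0; i < kSize; ++i) {
    data[i] = static_cast<float>(i); // in-bounds writes only
  }

  // An out-of-bounds write (data[-1], data[kSize], ...) would overwrite
  // a sentinel and trip one of these checks.
  for (int i = 0; i < kPad; ++i) {
    assert(buf[i] == kSentinel);
    assert(buf[kPad + kSize + i] == kSentinel);
  }
  return 0;
}
```
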
- void ValidateWatermark() const { - for (const auto i : c10::irange(kPaddingSize)) { - ASSERT_EQ(data_[i], kPaddingValue); - ASSERT_EQ(data_[i + total_size_ + kPaddingSize], kPaddingValue); - } - } - - void CheckBackup() const { - ValidateWatermark(); - DCHECK(backup_data_.size() == data_.size()) - << "Please make sure you have call Backup() before calling CheckBackup()"; - for (const auto i : c10::irange(total_size_)) { - ASSERT_EQ(data_[i + kPaddingSize], backup_data_[i + kPaddingSize]); - } - } - - private: - std::vector data_; - std::vector backup_data_; - T kPaddingValue = DefaultPaddedValue::kValue; -}; - -template -inline CodeGen::CallArg::CallArg(const PaddedBuffer& buffer) - : data_(const_cast(buffer.data())) {} - -template -std::string CompareErrorMsg( - const PaddedBuffer& v1, - const PaddedBuffer& v2, - int index) { - std::ostringstream oss; - oss << "index: " << index << ", v1: (" << v1.name() << ", " << v1(index) - << ")" - << ", v2: (" << v2.name() << ", " << v2(index) << ")"; - return oss.str(); -} - -template -void ExpectAllEqual(const PaddedBuffer& f1, const PaddedBuffer& f2) { - const std::vector& v1 = f1.data_; - const std::vector& v2 = f2.data_; - const int kPaddingSize = f1.kPaddingSize; - const int total_size = f1.total_size_; - ASSERT_EQ(v1.size(), v2.size()); - f1.ValidateWatermark(); - f2.ValidateWatermark(); - for (const auto i : c10::irange(total_size)) { - ASSERT_EQ(v1[kPaddingSize + i], v2[kPaddingSize + i]); - } -} - -template -void ExpectAllNear( - const PaddedBuffer& f1, - const PaddedBuffer& f2, - float abs_error) { - const std::vector& v1 = f1.data_; - const std::vector& v2 = f2.data_; - const int kPaddingSize = f1.kPaddingSize; - const int total_size = f1.total_size_; - ASSERT_EQ(v1.size(), v2.size()); - f1.ValidateWatermark(); - f2.ValidateWatermark(); - for (const auto i : c10::irange(total_size)) { - ASSERT_NEAR(v1[kPaddingSize + i], v2[kPaddingSize + i], abs_error); - } -} - -} // namespace tensorexpr -} // namespace jit -} // namespace torch diff --git a/test/cpp/tensorexpr/test_approx.cpp b/test/cpp/tensorexpr/test_approx.cpp deleted file mode 100644 index e1a576aecf52..000000000000 --- a/test/cpp/tensorexpr/test_approx.cpp +++ /dev/null @@ -1,96 +0,0 @@ -#ifdef TORCH_ENABLE_LLVM - -#include -#include -#include -#include -#include -#include -#include - -using namespace torch::indexing; -namespace te = torch::jit::tensorexpr; - -static void vectorize(te::LoopNest* ln, te::Tensor target, int width) { - auto loops = ln->getLoopStmtsFor(target); - te::ForPtr inner, tail; - ln->splitWithTail(loops[0], width, &inner, &tail); - ASSERT_TRUE(te::LoopNest::vectorize(inner)); -} - -std::string diffs(const at::Tensor& a, const at::Tensor& b) { - auto diff = torch::abs(a.flatten() - b.flatten()); - auto count_diffs = torch::sum(diff > 0.f); - auto greatest_diff_index = torch::argmax(diff); - std::stringstream ss; - ss << "Found " << count_diffs << " unequal element(s). 
" - << "The greatest difference was " << diff.index({greatest_diff_index}) - << " at index " << greatest_diff_index; - return ss.str(); -} - -TEST(Approx, log_vml) { - te::VarHandle N("N", te::kInt); - te::BufHandle A("A", {N}, te::kFloat); - te::Tensor B = te::Compute( - "B", {N}, [&](const te::VarHandle& i) { return log_vml(A.load(i)); }); - - te::LoopNest ln({B}); - ln.prepareForCodegen(); - vectorize(&ln, B, 8); - te::StmtPtr s = ln.root_stmt(); - s = te::IRSimplifier::simplify(s); - te::LLVMCodeGen cg(s, {A, B, N}); - - auto eps = std::numeric_limits::epsilon(); - auto test = [&](const at::Tensor& A_t) { - at::Tensor B_ref = at::log(A_t); - at::Tensor B_t = at::empty_like(A_t); - auto ap = A_t.data_ptr(); - auto bp = B_t.data_ptr(); - cg.call({ap, bp, A_t.numel()}); - // Results should be bit-identical. - ASSERT_TRUE(torch::allclose( - B_t, B_ref, /*rtol=*/eps, /*atol=*/0.0f, /*equal_nan=*/true)) - << "Input[:8]\n" - << A_t.index({Slice(0, 8)}) << "\n" - << "Test[:8]\n" - << B_t.index({Slice(0, 8)}) << "\n" - << "Ref[:8]\n" - << B_ref.index({Slice(0, 8)}) << diffs(B_t, B_ref); - }; - - // Generate every single-precision FP value in [1.0, 2.0). - at::Tensor A_t = torch::arange(1.0f, 2.0f, eps); - ASSERT_EQ(A_t.numel(), 1 << 23); - - test(A_t); - - test(A_t * 2.0f); - test(A_t * 0.5f); - - test(A_t * 4.0f); - test(A_t * 0.25f); - - test(A_t * powf(2.0f, 16)); - test(A_t * powf(2.0f, -16)); - - test(A_t * powf(2.0f, 126)); - test(A_t * powf(2.0f, -126)); - - test(torch::full({32}, INFINITY)); - test(torch::full({32}, NAN)); - - auto min = std::numeric_limits::min(); - auto denorm_min = std::numeric_limits::denorm_min(); - - // Denormals aren't bit precise, because sleef isn't bit-precise either. - A_t = torch::arange(0.0f, min, denorm_min); - ASSERT_EQ(A_t.numel(), 1 << 23); - auto B_ref = at::log(A_t); - auto B_t = at::empty_like(B_ref); - cg.call({A_t.data_ptr(), B_t.data_ptr(), A_t.numel()}); - ASSERT_TRUE(torch::allclose(B_t, B_ref)); -} - -#endif // TORCH_ENABLE_LLVM diff --git a/test/cpp/tensorexpr/test_aten.cpp b/test/cpp/tensorexpr/test_aten.cpp deleted file mode 100644 index 34ce2bd069d5..000000000000 --- a/test/cpp/tensorexpr/test_aten.cpp +++ /dev/null @@ -1,1068 +0,0 @@ -#include -#include -#include - -#include - -#include -#include -#include "test/cpp/tensorexpr/padded_buffer.h" -#include "test/cpp/tensorexpr/test_base.h" -#include "torch/csrc/jit/tensorexpr/ir_printer.h" - -namespace torch { -namespace jit { - -using namespace torch::jit::tensorexpr; - -TEST(ATen, _cast_Float) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - ExprHandle to_float = Cast::make(kFloat, load_a); - StmtPtr store_b = b_buf.store({index}, to_float); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = i; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); - ir_eval(a_v, b_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), i); - ASSERT_EQ(b_v(i), static_cast(i)); - } -} - -TEST(ATen, negInt) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - ExprHandle to_float = 
Sub::make(0, load_a); - StmtPtr store_b = b_buf.store({index}, to_float); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = i; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); - ir_eval(a_v, b_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), i); - ASSERT_EQ(b_v(i), -static_cast(i)); - } -} - -TEST(ATen, negFloat) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - ExprHandle to_float = Sub::make(0, load_a); - StmtPtr store_b = b_buf.store({index}, to_float); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = i; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); - ir_eval(a_v, b_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), i); - ASSERT_EQ(b_v(i), -i); - } -} - -TEST(ATen, addInt) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); - BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kInt); - BufHandle d_buf("D", {ExprHandle(kTotalSize)}, kInt); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - ExprHandle load_b = b_buf.load(index); - ExprHandle load_c = c_buf.load(index); - StmtPtr store_d = d_buf.store({index}, load_a + load_b * load_c); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_d); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - PaddedBuffer c_v(kTotalSize); - PaddedBuffer d_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = i; - b_v(i) = 2 * i + 1; - c_v(i) = 3 * i + 2; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf, d_buf}); - ir_eval(a_v, b_v, c_v, d_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), i); - ASSERT_EQ(b_v(i), 2 * i + 1); - ASSERT_EQ(c_v(i), 3 * i + 2); - ASSERT_EQ(d_v(i), a_v(i) + b_v(i) * c_v(i)); - } -} - -TEST(ATen, addFloat) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); - BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kFloat); - BufHandle d_buf("D", {ExprHandle(kTotalSize)}, kFloat); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - ExprHandle load_b = b_buf.load(index); - ExprHandle load_c = c_buf.load(index); - StmtPtr store_d = d_buf.store({index}, load_a + load_b * load_c); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_d); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - PaddedBuffer c_v(kTotalSize); - PaddedBuffer d_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = i; - b_v(i) = 2 * i + 1; - c_v(i) = 3 * i + 2; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf, d_buf}); - ir_eval(a_v, b_v, c_v, d_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), i); - ASSERT_EQ(b_v(i), 2 * i + 1); - ASSERT_EQ(c_v(i), 3 * i + 2); - ASSERT_EQ(d_v(i), a_v(i) + b_v(i) * c_v(i)); - } -} - -TEST(ATen, subInt) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); - BufHandle 
b_buf("B", {ExprHandle(kTotalSize)}, kInt); - BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kInt); - BufHandle d_buf("D", {ExprHandle(kTotalSize)}, kInt); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - ExprHandle load_b = b_buf.load(index); - ExprHandle load_c = c_buf.load(index); - StmtPtr store_d = d_buf.store({index}, load_a - load_b * load_c); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_d); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - PaddedBuffer c_v(kTotalSize); - PaddedBuffer d_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = i; - b_v(i) = 2 * i + 1; - c_v(i) = 3 * i + 2; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf, d_buf}); - ir_eval(a_v, b_v, c_v, d_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), i); - ASSERT_EQ(b_v(i), 2 * i + 1); - ASSERT_EQ(c_v(i), 3 * i + 2); - ASSERT_EQ(d_v(i), a_v(i) - b_v(i) * c_v(i)); - } -} - -TEST(ATen, subFloat) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); - BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kFloat); - BufHandle d_buf("D", {ExprHandle(kTotalSize)}, kFloat); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - ExprHandle load_b = b_buf.load(index); - ExprHandle load_c = c_buf.load(index); - StmtPtr store_d = d_buf.store({index}, load_a - load_b * load_c); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_d); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - PaddedBuffer c_v(kTotalSize); - PaddedBuffer d_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = i; - b_v(i) = 2 * i + 1; - c_v(i) = 3 * i + 2; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf, d_buf}); - ir_eval(a_v, b_v, c_v, d_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), i); - ASSERT_EQ(b_v(i), 2 * i + 1); - ASSERT_EQ(c_v(i), 3 * i + 2); - ASSERT_EQ(d_v(i), a_v(i) - b_v(i) * c_v(i)); - } -} - -TEST(ATen, lerp) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); - BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kFloat); - BufHandle d_buf("D", {ExprHandle(kTotalSize)}, kFloat); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - ExprHandle load_b = b_buf.load(index); - ExprHandle load_c = c_buf.load(index); - StmtPtr store_d = d_buf.store({index}, load_a + load_c * (load_b - load_a)); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_d); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - PaddedBuffer c_v(kTotalSize); - PaddedBuffer d_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = i; - b_v(i) = 2 * i + 1; - c_v(i) = 3 * i + 2; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf, d_buf}); - ir_eval(a_v, b_v, c_v, d_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), i); - ASSERT_EQ(b_v(i), 2 * i + 1); - ASSERT_EQ(c_v(i), 3 * i + 2); - ASSERT_EQ(d_v(i), a_v(i) + c_v(i) * (b_v(i) - a_v(i))); - } -} - -TEST(ATen, addcmulInt) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); - BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kInt); - BufHandle d_buf("D", {ExprHandle(kTotalSize)}, kInt); - BufHandle e_buf("E", 
{ExprHandle(kTotalSize)}, kInt); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - ExprHandle load_b = b_buf.load(index); - ExprHandle load_c = c_buf.load(index); - ExprHandle load_d = d_buf.load(index); - StmtPtr store_e = e_buf.store({index}, load_a + load_b * load_c * load_d); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_e); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - PaddedBuffer c_v(kTotalSize); - PaddedBuffer d_v(kTotalSize); - PaddedBuffer e_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = i; - b_v(i) = 2 * i + 1; - c_v(i) = 3 * i + 2; - d_v(i) = 5 * i + 3; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf, d_buf, e_buf}); - ir_eval(a_v, b_v, c_v, d_v, e_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), i); - ASSERT_EQ(b_v(i), 2 * i + 1); - ASSERT_EQ(c_v(i), 3 * i + 2); - ASSERT_EQ(d_v(i), 5 * i + 3); - ASSERT_EQ(e_v(i), a_v(i) + b_v(i) * c_v(i) * d_v(i)); - } -} - -TEST(ATen, addcmulFloat) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); - BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kFloat); - BufHandle d_buf("D", {ExprHandle(kTotalSize)}, kFloat); - BufHandle e_buf("E", {ExprHandle(kTotalSize)}, kFloat); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - ExprHandle load_b = b_buf.load(index); - ExprHandle load_c = c_buf.load(index); - ExprHandle load_d = d_buf.load(index); - StmtPtr store_e = e_buf.store({index}, load_a + load_b * load_c * load_d); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_e); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - PaddedBuffer c_v(kTotalSize); - PaddedBuffer d_v(kTotalSize); - PaddedBuffer e_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = i; - b_v(i) = 2 * i + 1; - c_v(i) = 3 * i + 2; - d_v(i) = 5 * i + 3; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf, d_buf, e_buf}); - ir_eval(a_v, b_v, c_v, d_v, e_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), i); - ASSERT_EQ(b_v(i), 2 * i + 1); - ASSERT_EQ(c_v(i), 3 * i + 2); - ASSERT_EQ(d_v(i), 5 * i + 3); - ASSERT_FLOAT_EQ(e_v(i), a_v(i) + b_v(i) * c_v(i) * d_v(i)); - } -} - -TEST(ATen, mulInt) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); - BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kInt); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - ExprHandle load_b = b_buf.load(index); - StmtPtr store_c = c_buf.store({index}, load_a * load_b); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_c); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - PaddedBuffer c_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = i; - b_v(i) = 2 * i + 1; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf}); - ir_eval(a_v, b_v, c_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), i); - ASSERT_EQ(b_v(i), 2 * i + 1); - ASSERT_EQ(c_v(i), a_v(i) * b_v(i)); - } -} - -TEST(ATen, mulFloat) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); - BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kFloat); - - VarHandle index = VarHandle("index", kInt); - ExprHandle 
load_a = a_buf.load(index); - ExprHandle load_b = b_buf.load(index); - StmtPtr store_c = c_buf.store({index}, load_a * load_b); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_c); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - PaddedBuffer c_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = i; - b_v(i) = 2 * i + 1; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf}); - ir_eval(a_v, b_v, c_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), i); - ASSERT_EQ(b_v(i), 2 * i + 1); - ASSERT_EQ(c_v(i), a_v(i) * b_v(i)); - } -} - -TEST(ATen, divInt) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); - BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kInt); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - ExprHandle load_b = b_buf.load(index); - StmtPtr store_c = c_buf.store({index}, load_a / load_b); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_c); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - PaddedBuffer c_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = 2 * i + 1; - b_v(i) = i + 1; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf}); - ir_eval(a_v, b_v, c_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), 2 * i + 1); - ASSERT_EQ(b_v(i), i + 1); - ASSERT_EQ(c_v(i), a_v(i) / b_v(i)); - } -} - -TEST(ATen, divFloat) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); - BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kFloat); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - ExprHandle load_b = b_buf.load(index); - StmtPtr store_c = c_buf.store({index}, load_a / load_b); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_c); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - PaddedBuffer c_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = 2 * i + 1; - b_v(i) = i + 1; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf}); - ir_eval(a_v, b_v, c_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), 2 * i + 1); - ASSERT_EQ(b_v(i), i + 1); - ASSERT_EQ(c_v(i), a_v(i) / b_v(i)); - } -} - -TEST(ATen, maxInt) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); - BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kInt); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - ExprHandle load_b = b_buf.load(index); - StmtPtr store_c = c_buf.store({index}, Max::make(load_a, load_b, true)); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_c); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - PaddedBuffer c_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = i; - b_v(i) = 2 * i + 1; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf}); - ir_eval(a_v, b_v, c_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), i); - ASSERT_EQ(b_v(i), 2 * i + 1); - ASSERT_EQ(c_v(i), std::max(a_v(i), b_v(i))); - } -} - -TEST(ATen, maxFloat) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); - BufHandle c_buf("C", 
{ExprHandle(kTotalSize)}, kFloat); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - ExprHandle load_b = b_buf.load(index); - StmtPtr store_c = c_buf.store({index}, Max::make(load_a, load_b, true)); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_c); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - PaddedBuffer c_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = i; - b_v(i) = 2 * i + 1; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf}); - ir_eval(a_v, b_v, c_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), i); - ASSERT_EQ(b_v(i), 2 * i + 1); - ASSERT_EQ(c_v(i), std::fmax(a_v(i), b_v(i))); - } -} - -TEST(ATen, minInt) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); - BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kInt); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - ExprHandle load_b = b_buf.load(index); - StmtPtr store_c = c_buf.store({index}, Min::make(load_a, load_b, true)); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_c); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - PaddedBuffer c_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = i; - b_v(i) = 2 * i + 1; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf}); - ir_eval(a_v, b_v, c_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), i); - ASSERT_EQ(b_v(i), 2 * i + 1); - ASSERT_EQ(c_v(i), std::min(a_v(i), b_v(i))); - } -} - -TEST(ATen, minFloat) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); - BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kFloat); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - ExprHandle load_b = b_buf.load(index); - StmtPtr store_c = c_buf.store({index}, Min::make(load_a, load_b, true)); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_c); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - PaddedBuffer c_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = i; - b_v(i) = 2 * i + 1; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf}); - ir_eval(a_v, b_v, c_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), i); - ASSERT_EQ(b_v(i), 2 * i + 1); - ASSERT_EQ(c_v(i), std::fmin(a_v(i), b_v(i))); - } -} - -void __ubsan_ignore_float_divide_by_zero__ testATenreciprocal() { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - StmtPtr store_b = b_buf.store({index}, FloatImm::make(1.0f) / load_a); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = i; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); - ir_eval(a_v, b_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), i); - ASSERT_EQ(b_v(i), 1.0f / i); - } -} - -TEST(ATen, reluInt) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); - - VarHandle index = VarHandle("index", 
kInt); - ExprHandle load_a = a_buf.load(index); - StmtPtr store_b = b_buf.store({index}, Max::make(load_a, 0, false)); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = i - 64; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); - ir_eval(a_v, b_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), i - 64); - ASSERT_EQ(b_v(i), std::max(a_v(i), 0)); - } -} - -TEST(ATen, reluFloat) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - StmtPtr store_b = b_buf.store( - {index}, Max::make(load_a, 0, false) // relu does not propagate nans - ); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = i - 64; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); - ir_eval(a_v, b_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), i - 64); - ASSERT_EQ(b_v(i), std::fmax(a_v(i), 0)); - } -} - -TEST(ATen, logFloat) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - StmtPtr store_b = b_buf.store({index}, log(load_a)); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = i + 10; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); - ir_eval(a_v, b_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), i + 10); - ASSERT_EQ(b_v(i), std::log(a_v(i))); - } -} - -TEST(ATen, fastLogFloat) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - StmtPtr store_b = b_buf.store({index}, fast_log(load_a)); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = at::randn({1}).item().to(); - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); - ir_eval(a_v, b_v); - - for (const auto i : c10::irange(kTotalSize)) { - auto test = b_v(i); - auto ref = std::log(a_v(i)); - if (std::isnan(ref)) { - ASSERT_EQ(std::isnan(test), true); - } else { - ASSERT_FLOAT_EQ(test, ref); - } - } -} - -TEST(ATen, fastTanhFloat) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - StmtPtr store_b = b_buf.store({index}, fast_tanh(load_a)); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = at::randn({1}).item().to(); - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); - ir_eval(a_v, b_v); - - for (const auto i : c10::irange(kTotalSize)) { - auto test = b_v(i); - auto ref = 
std::tanh(a_v(i)); - if (std::isnan(ref)) { - ASSERT_EQ(std::isnan(test), true); - } else { - ASSERT_NEAR(test, ref, 1e-6); - } - } -} - -TEST(ATen, fastSigmoidFloat) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - StmtPtr store_b = b_buf.store({index}, fast_sigmoid(load_a)); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = at::randn({1}).item().to(); - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); - ir_eval(a_v, b_v); - - for (const auto i : c10::irange(kTotalSize)) { - auto test = b_v(i); - at::Tensor t = at::ones({1}) * a_v(i); - float ref = at::sigmoid(t).item().to(); - if (std::isnan(ref)) { - ASSERT_EQ(std::isnan(test), true); - } else { - ASSERT_NEAR(test, ref, 1e-6); - } - } -} - -TEST(ATen, log10Float) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - StmtPtr store_b = b_buf.store({index}, log10(load_a)); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = i + 10; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); - ir_eval(a_v, b_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), i + 10); - ASSERT_EQ(b_v(i), std::log10(a_v(i))); - } -} - -TEST(ATen, log2Float) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - StmtPtr store_b = b_buf.store({index}, log2(load_a)); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = i + 10; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); - ir_eval(a_v, b_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), i + 10); - ASSERT_EQ(b_v(i), std::log2(a_v(i))); - } -} - -TEST(ATen, expFloat) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - StmtPtr store_b = b_buf.store({index}, exp(load_a)); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) - a_v(i) = i / 10.0f; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); - ir_eval(a_v, b_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), i / 10.0f); - ASSERT_EQ(b_v(i), std::exp(a_v(i))); - } -} - -TEST(ATen, erfFloat) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); - - VarHandle index = VarHandle("index", kInt); - 
ExprHandle load_a = a_buf.load(index); - StmtPtr store_b = b_buf.store({index}, erf(load_a)); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) - a_v(i) = i / 10.0f; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); - ir_eval(a_v, b_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), i / 10.0f); - ASSERT_EQ(b_v(i), std::erf(a_v(i))); - } -} - -TEST(ATen, cosFloat) { - const int kTotalSize = 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - StmtPtr store_b = b_buf.store({index}, cos(load_a)); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) - a_v(i) = i / 10.0f; - } - - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); - ir_eval(a_v, b_v); - - for (const auto i : c10::irange(kTotalSize)) { - ASSERT_EQ(a_v(i), i / 10.0f); - ASSERT_EQ(b_v(i), std::cos(a_v(i))); - } -} - -TEST(ATen, eqInt) { - constexpr int N = 128; - BufHandle a("A", {N}, kInt); - BufHandle b("B", {N}, kInt); - BufHandle c("C", {N}, kInt); - std::vector a_buffer(N, 1); - std::vector b_buffer(N, 1); - std::vector c_buffer(N, 0); - - VarHandle i("i", kInt); - auto memcpy_expr = For::make( - i, - 0, - N, - c.store( - {i}, - CompareSelect::make( - a.load(i), b.load(i), CompareSelectOperation::kEQ))); - - SimpleIREvaluator ir_eval(memcpy_expr, {a, b, c}); - ir_eval(a_buffer, b_buffer, c_buffer); - - assertAllEqual(c_buffer, 1); -} - -TEST(ATen, geInt) { - constexpr int N = 128; - BufHandle a("A", {N}, kInt); - BufHandle b("B", {N}, kInt); - BufHandle c("C", {N}, kInt); - std::vector a_buffer(N, 5); - std::vector b_buffer(N, 5); - std::vector c_buffer(N, 0); - - VarHandle i("i", kInt); - auto memcpy_expr = For::make( - i, - 0, - N, - c.store( - {i}, - CompareSelect::make( - a.load(i), b.load(i), CompareSelectOperation::kGE))); - - SimpleIREvaluator ir_eval(memcpy_expr, {a, b, c}); - ir_eval(a_buffer, b_buffer, c_buffer); - - assertAllEqual(c_buffer, 1); -} - -TEST(ATen, gtInt) { - constexpr int N = 128; - BufHandle a("A", {N}, kInt); - BufHandle b("B", {N}, kInt); - BufHandle c("C", {N}, kInt); - std::vector a_buffer(N, 6); - std::vector b_buffer(N, 3); - std::vector c_buffer(N, 0); - - VarHandle i("i", kInt); - auto memcpy_expr = For::make( - i, - 0, - N, - c.store( - {i}, - CompareSelect::make( - a.load(i), b.load(i), CompareSelectOperation::kGT))); - - SimpleIREvaluator ir_eval(memcpy_expr, {a, b, c}); - ir_eval(a_buffer, b_buffer, c_buffer); - - assertAllEqual(c_buffer, 1); -} - -TEST(ATen, leInt) { - constexpr int N = 128; - BufHandle a("A", {N}, kInt); - BufHandle b("B", {N}, kInt); - BufHandle c("C", {N}, kInt); - std::vector a_buffer(N, 5); - std::vector b_buffer(N, 5); - std::vector c_buffer(N, 0); - - VarHandle i("i", kInt); - auto memcpy_expr = For::make( - i, - 0, - N, - c.store( - {i}, - CompareSelect::make( - a.load(i), b.load(i), CompareSelectOperation::kLE))); - - SimpleIREvaluator ir_eval(memcpy_expr, {a, b, c}); - 
ir_eval(a_buffer, b_buffer, c_buffer); - - assertAllEqual(c_buffer, 1); -} - -TEST(ATen, ltInt) { - constexpr int N = 128; - BufHandle a("A", {N}, kInt); - BufHandle b("B", {N}, kInt); - BufHandle c("C", {N}, kInt); - std::vector a_buffer(N, 5); - std::vector b_buffer(N, 5); - std::vector c_buffer(N, 1); - - VarHandle i("i", kInt); - auto memcpy_expr = For::make( - i, - 0, - N, - c.store( - {i}, - CompareSelect::make( - a.load(i), b.load(i), CompareSelectOperation::kLT))); - - SimpleIREvaluator ir_eval(memcpy_expr, {a, b, c}); - ir_eval(a_buffer, b_buffer, c_buffer); - - assertAllEqual(c_buffer, 0); -} - -} // namespace jit -} // namespace torch diff --git a/test/cpp/tensorexpr/test_base.h b/test/cpp/tensorexpr/test_base.h deleted file mode 100644 index 68b96fe6c90f..000000000000 --- a/test/cpp/tensorexpr/test_base.h +++ /dev/null @@ -1,89 +0,0 @@ -#pragma once - -#if defined(USE_GTEST) -#include -#include -#else -#include -#include "c10/util/Exception.h" -#include "test/cpp/tensorexpr/gtest_assert_float_eq.h" -#define ASSERT_EQ(x, y, ...) TORCH_INTERNAL_ASSERT((x) == (y), __VA_ARGS__) -#define ASSERT_FLOAT_EQ(x, y, ...) \ - TORCH_INTERNAL_ASSERT(AlmostEquals((x), (y)), __VA_ARGS__) -#define ASSERT_NE(x, y, ...) TORCH_INTERNAL_ASSERT((x) != (y), __VA_ARGS__) -#define ASSERT_GT(x, y, ...) TORCH_INTERNAL_ASSERT((x) > (y), __VA_ARGS__) -#define ASSERT_GE(x, y, ...) TORCH_INTERNAL_ASSERT((x) >= (y), __VA_ARGS__) -#define ASSERT_LT(x, y, ...) TORCH_INTERNAL_ASSERT((x) < (y), __VA_ARGS__) -#define ASSERT_LE(x, y, ...) TORCH_INTERNAL_ASSERT((x) <= (y), __VA_ARGS__) - -#define ASSERT_NEAR(x, y, a, ...) \ - TORCH_INTERNAL_ASSERT(std::fabs((x) - (y)) < (a), __VA_ARGS__) - -#define ASSERT_TRUE TORCH_INTERNAL_ASSERT -#define ASSERT_FALSE(x) ASSERT_TRUE(!(x)) -#define ASSERT_THROWS_WITH(statement, substring) \ - try { \ - (void)statement; \ - ASSERT_TRUE(false); \ - } catch (const std::exception& e) { \ - ASSERT_NE(std::string(e.what()).find(substring), std::string::npos); \ - } -#define ASSERT_ANY_THROW(statement) \ - { \ - bool threw = false; \ - try { \ - (void)statement; \ - } catch (const std::exception& e) { \ - threw = true; \ - } \ - ASSERT_TRUE(threw); \ - } - -#endif // defined(USE_GTEST) -#include -#include - -namespace torch { -namespace jit { -namespace tensorexpr { - -template -void ExpectAllNear( - const std::vector& v1, - const std::vector& v2, - V threshold, - const std::string& name = "") { - ASSERT_EQ(v1.size(), v2.size()); - for (size_t i = 0; i < v1.size(); i++) { - ASSERT_NEAR(v1[i], v2[i], threshold); - } -} - -template -void ExpectAllNear( - const std::vector& vec, - const U& val, - V threshold, - const std::string& name = "") { - for (size_t i = 0; i < vec.size(); i++) { - ASSERT_NEAR(vec[i], val, threshold); - } -} - -template -static void assertAllEqual(const std::vector& vec, const T& val) { - for (auto const& elt : vec) { - ASSERT_EQ(elt, val); - } -} - -template -static void assertAllEqual(const std::vector& v1, const std::vector& v2) { - ASSERT_EQ(v1.size(), v2.size()); - for (size_t i = 0; i < v1.size(); ++i) { - ASSERT_EQ(v1[i], v2[i]); - } -} -} // namespace tensorexpr -} // namespace jit -} // namespace torch diff --git a/test/cpp/tensorexpr/test_boundsinference.cpp b/test/cpp/tensorexpr/test_boundsinference.cpp deleted file mode 100644 index 2605842d6e74..000000000000 --- a/test/cpp/tensorexpr/test_boundsinference.cpp +++ /dev/null @@ -1,1019 +0,0 @@ -#include -#include -#include -#include - -#include - -#include -#include -#include -#include -#include -#include 
-#include -#include -#include -#include - -namespace torch { -namespace jit { - -using namespace torch::jit::tensorexpr; - -static void verifyConstBounds( - const TensorAccessBoundsInfo& access_info, - const std::vector>& ref) { - size_t ndim = ref.size(); - ASSERT_EQ(access_info.start.size(), ndim); - ASSERT_EQ(access_info.stop.size(), ndim); - for (const auto i : c10::irange(ndim)) { - if (ref[i].first >= 0) { // Negative values are used to skip the check - ASSERT_TRUE(access_info.start[i]->isConstant()); - int start_i = immediateAs(access_info.start[i]); - ASSERT_EQ(start_i, ref[i].first); - } - if (ref[i].second >= 0) { - ASSERT_TRUE(access_info.stop[i]->isConstant()); - int stop_i = immediateAs(access_info.stop[i]); - ASSERT_EQ(stop_i, ref[i].second); - } - } -} - -TEST(BoundsInference, _1) { - // Verify that bounds inference works for the following example: - // for i in 0..100: - // b[i] = a[i] - // For this loop bounds inference should yield the following: - // {{b, kStore, 0, 99}, {a, kLoad, 0, 99}} - ExprHandle n(100); - BufHandle a("a", {n}, kFloat); - Tensor b = Compute("b", {n}, [&](const VarHandle& i) { return a.load(i); }); - LoopNest l({b}); - auto bounds_info = inferBounds(l.root_stmt()); - - // We should have two entries: one for 'b' and one for 'a'. - ASSERT_EQ(bounds_info.size(), 2); - ASSERT_EQ(bounds_info.at(a.node()).size(), 1); - ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); - verifyConstBounds(bounds_info.at(a.node())[0], {{0, 99}}); - - ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); - ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kStore); - verifyConstBounds(bounds_info.at(b.buf())[0], {{0, 99}}); -} - -TEST(BoundsInference, _2) { - // Verify that bounds inference works for the following example: - // for i in 0..n: - // b[i] = a[i] - // For this loop bounds inference should yield the following: - // {{b, kStore, 0, n-1}, {a, kLoad, 0, n-1}} - VarHandle n("n", kInt); - BufHandle a("a", {n}, kFloat); - Tensor b = Compute("b", {n}, [&](const VarHandle& i) { return a.load(i); }); - LoopNest l({b}); - auto bounds_info = inferBounds(l.root_stmt()); - - // We should have two entries: one for 'b' and one for 'a'. - ASSERT_EQ(bounds_info.size(), 2); - ASSERT_EQ(bounds_info.at(a.node()).size(), 1); - ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); - verifyConstBounds(bounds_info.at(a.node())[0], {{0, -1}}); - - ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); - ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kStore); - verifyConstBounds(bounds_info.at(b.buf())[0], {{0, -1}}); -} - -TEST(BoundsInference, _3) { - // Verify that bounds inference works for the following example: - // for i in 0..100: - // b[i] = a[i] * a[i+10] - // For this loop bounds inference should yield the following: - // {{b, kStore, 0, 99}, {a, kLoad, 0, 109}} - ExprHandle n(100); - BufHandle a("a", {n + 10}, kFloat); - Tensor b = Compute( - "b", {n}, [&](const VarHandle& i) { return a.load(i) * a.load(i + 10); }); - LoopNest l({b}); - auto bounds_info = inferBounds(l.root_stmt()); - - // We should have two entries: one for 'b' and one for 'a'. 
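- // The a[i] * a[i+10] access widens the inferred load bounds of 'a' to [0, 109].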
- ASSERT_EQ(bounds_info.size(), 2); - ASSERT_EQ(bounds_info.at(a.node()).size(), 1); - ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); - verifyConstBounds(bounds_info.at(a.node())[0], {{0, 109}}); - - ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); - ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kStore); - verifyConstBounds(bounds_info.at(b.buf())[0], {{0, 99}}); -} - -TEST(BoundsInference, _4) { - // Verify that bounds inference works for the following example: - // - // for y in 0..200: - // for x in 0..320: - // b[y,x] = x*y - // for y in 0..200: - // for x in 0..320: - // c[y,x] = a[y,x] * b[y,x] - ExprHandle W(320); - ExprHandle H(200); - BufHandle a("a", {H, W}, kFloat); - Tensor b = Compute("b", {H, W}, [&](const VarHandle& y, const VarHandle& x) { - return x * y; - }); - Tensor c = Compute("c", {H, W}, [&](const VarHandle& y, const VarHandle& x) { - return a.load(y, x) * b.load(y, x); - }); - LoopNest l({c}); - std::vector loops = l.getLoopStmtsFor(c); - StmtPtr body = l.getLoopBodyFor(c); - { - // Infer bounds on the top-level loop scope - auto bounds_info = inferBounds(loops[0]); - ASSERT_EQ(bounds_info.size(), 3); - - ASSERT_EQ(bounds_info.at(a.node()).size(), 1); - ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); - verifyConstBounds(bounds_info.at(a.node())[0], {{0, 199}, {0, 319}}); - - ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); - ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kLoad); - verifyConstBounds(bounds_info.at(b.buf())[0], {{0, 199}, {0, 319}}); - - ASSERT_EQ(bounds_info.at(c.buf()).size(), 1); - ASSERT_EQ(bounds_info.at(c.buf())[0].kind, kStore); - verifyConstBounds(bounds_info.at(c.buf())[0], {{0, 199}, {0, 319}}); - } - { - // Infer bounds on the inner loop scope - auto bounds_info = inferBounds(loops[1]); - ASSERT_EQ(bounds_info.size(), 3); - - ASSERT_EQ(bounds_info.at(a.node()).size(), 1); - ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); - verifyConstBounds(bounds_info.at(a.node())[0], {{-1, -1}, {0, 319}}); - - ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); - ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kLoad); - verifyConstBounds(bounds_info.at(b.buf())[0], {{-1, -1}, {0, 319}}); - - ASSERT_EQ(bounds_info.at(c.buf()).size(), 1); - ASSERT_EQ(bounds_info.at(c.buf())[0].kind, kStore); - verifyConstBounds(bounds_info.at(c.buf())[0], {{-1, -1}, {0, 319}}); - } - { - // Infer bounds on the inner loop body's scope - auto bounds_info = inferBounds(body); - ASSERT_EQ(bounds_info.size(), 3); - - ASSERT_EQ(bounds_info.at(a.node()).size(), 1); - ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); - verifyConstBounds(bounds_info.at(a.node())[0], {{-1, -1}, {-1, -1}}); - - ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); - ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kLoad); - verifyConstBounds(bounds_info.at(b.buf())[0], {{-1, -1}, {-1, -1}}); - - ASSERT_EQ(bounds_info.at(c.buf()).size(), 1); - ASSERT_EQ(bounds_info.at(c.buf())[0].kind, kStore); - verifyConstBounds(bounds_info.at(c.buf())[0], {{-1, -1}, {-1, -1}}); - } -} - -TEST(BoundsInference, _5) { - // Verify that bounds inference works for the following example: - // for i in 0..100: - // b[i] = a[i] - // - // ==> split ==> - // - // for i_outer in 0..100/16: - // for i_inner in 0..16: - // b[i_outer * 16 + i_inner] = a[i_outer * 16 + i_inner] - // for i_tail in 0..100%16: - // b[i_tail + (100/16)*16] = a[i_tail + (100/16)*16]; - ExprHandle n(100); - BufHandle a("a", {n}, kFloat); - Tensor b = Compute("b", {n}, [&](const VarHandle& i) { return a.load(i); }); - LoopNest l({b}); - - ForPtr inner; - ForPtr tail; - 
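// splitWithTail(16) on a 100-iteration loop yields a main loop over [0, 96) plus a 4-iteration tail. -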
std::vector loops = l.getLoopStmtsFor(b); - LoopNest::splitWithTail(loops[0], 16, &inner, &tail); - ForPtr outer = loops[0]; - - { - // Verify inferred bounds for the outer loop - auto bounds_info = inferBounds(outer); - ASSERT_EQ(bounds_info.size(), 2); - - ASSERT_EQ(bounds_info.at(a.node()).size(), 1); - ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); - verifyConstBounds(bounds_info.at(a.node())[0], {{0, 95}}); - - ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); - ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kStore); - verifyConstBounds(bounds_info.at(b.buf())[0], {{0, 95}}); - } - { - // Verify inferred bounds for the tail loop - auto bounds_info = inferBounds(tail); - ASSERT_EQ(bounds_info.size(), 2); - - ASSERT_EQ(bounds_info.at(a.node()).size(), 1); - ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); - verifyConstBounds(bounds_info.at(a.node())[0], {{96, 99}}); - - ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); - ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kStore); - verifyConstBounds(bounds_info.at(b.buf())[0], {{96, 99}}); - } -} - -TEST(BoundsInference, _6) { - // Verify that bounds inference works for the following example: - // - // for y in 0..200: - // for x in 0..320: - // b[y,x] = x*y - // for y in 0..20: - // for x in 0..32: - // c[y,x] = a[y+100,x+100] * b[y*2,x*5] - ExprHandle W(320); - ExprHandle H(200); - ExprHandle CW(32); - ExprHandle CH(20); - BufHandle a("a", {H, W}, kFloat); - Tensor b = Compute("b", {H, W}, [&](const VarHandle& y, const VarHandle& x) { - return x * y; - }); - Tensor c = - Compute("c", {CH, CW}, [&](const VarHandle& y, const VarHandle& x) { - return a.load(y + 100, x + 100) * b.load(y * 2, x * 5); - }); - LoopNest l({c}); - std::vector loops = l.getLoopStmtsFor(c); - StmtPtr body = l.getLoopBodyFor(c); - { - // Infer bounds on the top-level loop scope - auto bounds_info = inferBounds(loops[0]); - ASSERT_EQ(bounds_info.size(), 3); - - ASSERT_EQ(bounds_info.at(a.node()).size(), 1); - ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); - verifyConstBounds(bounds_info.at(a.node())[0], {{100, 119}, {100, 131}}); - - ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); - ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kLoad); - verifyConstBounds(bounds_info.at(b.buf())[0], {{0, 38}, {0, 155}}); - - ASSERT_EQ(bounds_info.at(c.buf()).size(), 1); - ASSERT_EQ(bounds_info.at(c.buf())[0].kind, kStore); - verifyConstBounds(bounds_info.at(c.buf())[0], {{0, 19}, {0, 31}}); - } - { - // Infer bounds on the inner loop scope - auto bounds_info = inferBounds(loops[1]); - ASSERT_EQ(bounds_info.size(), 3); - - ASSERT_EQ(bounds_info.at(a.node()).size(), 1); - ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); - verifyConstBounds(bounds_info.at(a.node())[0], {{-1, -1}, {100, 131}}); - - ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); - ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kLoad); - verifyConstBounds(bounds_info.at(b.buf())[0], {{-1, -1}, {0, 155}}); - - ASSERT_EQ(bounds_info.at(c.buf()).size(), 1); - ASSERT_EQ(bounds_info.at(c.buf())[0].kind, kStore); - verifyConstBounds(bounds_info.at(c.buf())[0], {{-1, -1}, {0, 31}}); - } - { - // Infer bounds on the inner loop body's scope - auto bounds_info = inferBounds(body); - ASSERT_EQ(bounds_info.size(), 3); - - ASSERT_EQ(bounds_info.at(a.node()).size(), 1); - ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); - verifyConstBounds(bounds_info.at(a.node())[0], {{-1, -1}, {-1, -1}}); - - ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); - ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kLoad); - verifyConstBounds(bounds_info.at(b.buf())[0], {{-1, 
-1}, {-1, -1}}); - - ASSERT_EQ(bounds_info.at(c.buf()).size(), 1); - ASSERT_EQ(bounds_info.at(c.buf())[0].kind, kStore); - verifyConstBounds(bounds_info.at(c.buf())[0], {{-1, -1}, {-1, -1}}); - } -} - -TEST(BoundsInference, Adjacent) { - ExprHandle H(6); - BufHandle a("a", {20}, kFloat); - Tensor b = Compute("b", {H}, [&](const VarHandle& x) { return a.load(x); }); - Tensor c = - Compute("c", {H}, [&](const VarHandle& x) { return a.load(x + H); }); - LoopNest l({b, c}); - std::vector loops = NodeFinder::find(l.root_stmt()); - - { - // Infer bounds on the top-level loop scope - auto bounds_info = inferBounds(loops[0]); - ASSERT_EQ(bounds_info.size(), 2); - - // reads from a[0:5], writes to b[0:5] - ASSERT_EQ(bounds_info.at(a.node()).size(), 1); - ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); - verifyConstBounds(bounds_info.at(a.node())[0], {{0, 5}}); - - ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); - ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kStore); - verifyConstBounds(bounds_info.at(b.buf())[0], {{0, 5}}); - } - { - // Infer bounds on the inner loop scope - auto bounds_info = inferBounds(loops[1]); - ASSERT_EQ(bounds_info.size(), 2); - - // reads from a[0+6:5+6], writes to c[0:5] - ASSERT_EQ(bounds_info.at(a.node()).size(), 1); - ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); - verifyConstBounds(bounds_info.at(a.node())[0], {{6, 11}}); - - ASSERT_EQ(bounds_info.at(c.buf()).size(), 1); - ASSERT_EQ(bounds_info.at(c.buf())[0].kind, kStore); - verifyConstBounds(bounds_info.at(c.buf())[0], {{0, 5}}); - } - { - // Infer bounds on the high level program. - auto bounds_info = inferBounds(l.root_stmt()); - ASSERT_EQ(bounds_info.size(), 3); - - // Should be union of above 2 bounds, but this time the bounds of A can be - // merged. - ASSERT_EQ(bounds_info.at(a.node()).size(), 1); - ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); - verifyConstBounds(bounds_info.at(a.node())[0], {{0, 11}}); - - ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); - ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kStore); - verifyConstBounds(bounds_info.at(b.buf())[0], {{0, 5}}); - - ASSERT_EQ(bounds_info.at(c.buf()).size(), 1); - ASSERT_EQ(bounds_info.at(c.buf())[0].kind, kStore); - verifyConstBounds(bounds_info.at(c.buf())[0], {{0, 5}}); - } -} - -TEST(BoundsInference, MultipleTopLoopLoad) { - BufHandle a("a", {100}, kFloat); - Tensor b = Compute("b", {64}, [&](const VarHandle& x) { return a.load(x); }); - Tensor c = - Compute("c", {32}, [&](const VarHandle& x) { return a.load(x + 10); }); - Tensor d = - Compute("d", {96}, [&](const VarHandle& x) { return a.load(x + 2); }); - LoopNest l({b, c, d}); - - auto bounds_info = inferBounds(l.root_stmt()); - - ASSERT_EQ(bounds_info.size(), 4); - - // a only read. - { - auto bounds = bounds_info[a.node()]; - ASSERT_EQ(bounds.size(), 1); - // One dimension. - auto bound = bounds[0]; - ASSERT_EQ(bound.kind, TensorAccessKind::kLoad); - // Bounds: - // start: Min of the 3 load bounds = Min of loop starts + offset = 0+0 (b). - // stop: Max of the 3 load bounds = Max of loop stops + offset - 1 = - // 96 + 2 - 1 (d). - verifyConstBounds(bound, {{0, 97}}); - } - - // b, c, d only written. - { - auto bounds = bounds_info[b.buf()]; - ASSERT_EQ(bounds.size(), 1); - auto bound = bounds[0]; - ASSERT_EQ(bound.kind, TensorAccessKind::kStore); - // Just the loop extents for b. 
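- // (A loop of extent 64 gives the inclusive store range [0, 63].)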
- verifyConstBounds(bound, {{0, 63}}); - } - { - auto bounds = bounds_info[c.buf()]; - ASSERT_EQ(bounds.size(), 1); - auto bound = bounds[0]; - ASSERT_EQ(bound.kind, TensorAccessKind::kStore); - // Just the loop extents for c. - verifyConstBounds(bound, {{0, 31}}); - } - { - auto bounds = bounds_info[d.buf()]; - ASSERT_EQ(bounds.size(), 1); - auto bound = bounds[0]; - ASSERT_EQ(bound.kind, TensorAccessKind::kStore); - // Just the loop extents for d. - verifyConstBounds(bound, {{0, 95}}); - } -} - -TEST(BoundsInference, MultipleTopLoopStore) { - BufHandle a("a", {100}, kFloat); - BufHandle b("b", {100}, kFloat); - BufHandle c("c", {100}, kFloat); - BufHandle d("d", {100}, kFloat); - VarHandle x("x", kInt); - - // Same as above but the offsets are on the Store now. - // Can't do this through ComputeAPI without transforms we don't have yet. - StmtPtr stmt = Block::make( - {For::make(x, 0, 64, Store::make(b, {x}, Load::make(a, {x}))), - For::make(x, 0, 32, Store::make(c, {x + 10}, Load::make(a, {x}))), - For::make(x, 0, 96, Store::make(d, {x + 2}, Load::make(a, {x})))}); - - auto bounds_info = inferBounds(stmt); - - ASSERT_EQ(bounds_info.size(), 4); - - // a only read. - { - auto bounds = bounds_info[a.node()]; - ASSERT_EQ(bounds.size(), 1); - // One dimension. - auto bound = bounds[0]; - ASSERT_EQ(bound.kind, TensorAccessKind::kLoad); - // Bounds: there are no offsets, so this is just the max loop bounds. - verifyConstBounds(bound, {{0, 95}}); - } - - // b, c, d only written. - { - auto bounds = bounds_info[b.node()]; - ASSERT_EQ(bounds.size(), 1); - auto bound = bounds[0]; - ASSERT_EQ(bound.kind, TensorAccessKind::kStore); - // This should be equivalent to {offset, extent + offset} for the b loop. - // b loop has no offset, so just the loop extents. - verifyConstBounds(bound, {{0, 63}}); - } - { - auto bounds = bounds_info[c.node()]; - ASSERT_EQ(bounds.size(), 1); - auto bound = bounds[0]; - ASSERT_EQ(bound.kind, TensorAccessKind::kStore); - // This should be equivalent to {offset, extent + offset} for the c loop. - // Offset is 10, extent is 32-1. - verifyConstBounds(bound, {{10, 41}}); - } - { - auto bounds = bounds_info[d.node()]; - ASSERT_EQ(bounds.size(), 1); - auto bound = bounds[0]; - ASSERT_EQ(bound.kind, TensorAccessKind::kStore); - // This should be equivalent to {offset, extent + offset} for the d loop. - // Offset is 2, extent is 96-1. - verifyConstBounds(bound, {{2, 97}}); - } -} - -TEST(BoundsInference, CacheReads) { - Tensor A = Compute("A", {64, 64}, [](const VarHandle& i, const VarHandle& j) { - return i * j; - }); - Tensor B = - Compute("B", {20, 10}, [&](const VarHandle& i, const VarHandle& j) { - return A.load(i + 30, j + 3); - }); - Tensor C = - Compute("C", {20, 10}, [&](const VarHandle& i, const VarHandle& j) { - return A.load(i + 10, j + 20) + A.load(i + 30, j + 40); - }); - - LoopNest l({B, C}); - auto bounds_info_before = inferBounds(l.root_stmt()); - - StmtPtr j_loop = l.getLoopStmtsFor(B)[1]; - LoopNest::cacheAccesses(A.buf(), "A_local", j_loop); - - auto bounds_info_after = inferBounds(l.root_stmt()); - - // CacheAccesses should not change existing bounds, but add a new one for the - // cache. - for (auto& pair : bounds_info_after) { - auto beforeIt = bounds_info_before.find(pair.first); - if (beforeIt != bounds_info_before.end()) { - // Same number of TensorAccessBoundInfos. 
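- // Every pre-existing access must keep identical per-dimension bounds; compared below via exprEquals.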
- ASSERT_EQ(pair.second.size(), beforeIt->second.size()); - - for (const auto i : c10::irange(pair.second.size())) { - TensorAccessBoundsInfo& after = pair.second[i]; - TensorAccessBoundsInfo& before = beforeIt->second[i]; - // Same number of dimensions. - ASSERT_EQ(before.start.size(), after.start.size()); - - // Bounds are equal. - for (const auto j : c10::irange(before.start.size())) { - ASSERT_TRUE(exprEquals(before.start[j], after.start[j])); - ASSERT_TRUE(exprEquals(before.stop[j], after.stop[j])); - } - } - } else { - // This should be the cache. - ASSERT_EQ(pair.first->name_hint(), "A_local"); - // Should have both a load and a store. - ASSERT_EQ(pair.second.size(), 2); - TensorAccessBoundsInfo& first = pair.second[0]; - TensorAccessBoundsInfo& second = pair.second[1]; - - ASSERT_NE(first.kind, second.kind); - // 2 dimensions. - ASSERT_EQ(first.start.size(), second.start.size()); - ASSERT_EQ(first.start.size(), 2); - - // bounds for load and store are equal. - for (const auto j : c10::irange(first.start.size())) { - ASSERT_TRUE(exprEquals(first.start[j], second.start[j])); - ASSERT_TRUE(exprEquals(first.stop[j], second.stop[j])); - } - } - } -} - -TEST(BoundsInference, Flattened) { - Tensor b = Compute( - "b", - {3, 4, 5}, - [&](const VarHandle& z, const VarHandle& y, const VarHandle& x) { - return x * y + z; - }); - - LoopNest l({b}); - // Flatten indices. - l.prepareForCodegen(); - auto bounds_info = inferBounds(l.root_stmt()); - - // There's only one buffer. - ASSERT_EQ(bounds_info.size(), 1); - auto& TABI = bounds_info[b.buf()][0]; - ASSERT_EQ(TABI.kind, TensorAccessKind::kStore); - // Flattened bounds should have a single dimension. - ASSERT_EQ(TABI.start.size(), 1); - ASSERT_EQ(TABI.stop.size(), 1); - - // Bounds should be 0 -> (3*4*5)-1 - ASSERT_TRUE(exprEquals(TABI.start[0], alloc(0))); - ASSERT_TRUE(exprEquals(TABI.stop[0], alloc(3 * 4 * 5 - 1))); -} - -TEST(BoundsInference, GetPotentialHazards) { - BufHandle a("A", {5}, kInt); - BufHandle b("B", {5}, kInt); - BufHandle c("C", {5}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - - using namespace analysis; - - { - /* - * A[0] = B[0]; - * B[0] = 3; WAR on B - * A[0] = B[0]; WAW on A, RAW on B - * C[0] = 5; - */ - - StorePtr store1 = Store::make(a, {0}, Load::make(b, {0})); - StorePtr store2 = Store::make(b, {0}, 3); - StorePtr store3 = Store::make(a, {0}, Load::make(b, {0})); - StorePtr store4 = Store::make(c, {0}, 5); - StmtPtr stmt = Block::make({store1, store2, store3, store4}); - - MemDependencyChecker analyzer; - stmt->accept(&analyzer); - - ASSERT_EQ( - HazardKind::WriteAfterRead, - getPotentialHazards(analyzer, store1, store2)); - - ASSERT_EQ( - HazardKind::ReadAfterWrite, - getPotentialHazards(analyzer, store2, store3)); - - ASSERT_EQ( - HazardKind::WriteAfterWrite, - getPotentialHazards(analyzer, store1, store3)); - - // Fourth store has no dependencies - ASSERT_EQ( - HazardKind::NoDependency, - getPotentialHazards(analyzer, store1, store4)); - ASSERT_EQ( - HazardKind::NoDependency, - getPotentialHazards(analyzer, store2, store4)); - ASSERT_EQ( - HazardKind::NoDependency, - getPotentialHazards(analyzer, store3, store4)); - } -} - -TEST(BoundsInference, GetPotentialHazardsLoopNoHazard) { - Tensor A = Compute("A", {64, 64}, [](const VarHandle& i, const VarHandle& j) { - return i * j; - }); - Tensor B = Compute("B", {64, 64}, [](const VarHandle& i, const VarHandle& j) { - return (i + 1) * (j + 1); - }); - - LoopNest l({A, B}); - - using namespace analysis; - - MemDependencyChecker analyzer; - 
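// Visit the whole nest so the checker records every load and store before querying hazards. -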
l.root_stmt()->accept(&analyzer); - - ForPtr loopRootA = l.getLoopStmtsFor(A)[0]; - ForPtr loopRootB = l.getLoopStmtsFor(B)[0]; - - // No dependencies between loops. - ASSERT_EQ( - HazardKind::NoDependency, - getPotentialHazards(analyzer, loopRootA, loopRootB)); -} - -TEST(BoundsInference, GetPotentialHazardsLoopCall) { - Tensor A = Compute("A", {64, 64}, [](const VarHandle& i, const VarHandle& j) { - return i * j; - }); - Tensor B = - Compute("B", {64, 64}, [&](const VarHandle& i, const VarHandle& j) { - return A.load(i, j) + 5; - }); - - LoopNest l({A, B}); - - using namespace analysis; - - MemDependencyChecker analyzer; - l.root_stmt()->accept(&analyzer); - - ForPtr loopRootA = l.getLoopStmtsFor(A)[0]; - ForPtr loopRootB = l.getLoopStmtsFor(B)[0]; - - ASSERT_EQ( - HazardKind::ReadAfterWrite, - getPotentialHazards(analyzer, loopRootA, loopRootB)); -} - -TEST(BoundsInference, GetPotentialHazardsLoopSplit) { - Tensor A = Compute("A", {64, 64}, [](const VarHandle& i, const VarHandle& j) { - return i * j; - }); - - LoopNest l({A}); - ForPtr inner, tail; - - // Splitting with a tail by a factor that does not evenly divide the extent - // creates a tail loop which also writes to A. - ForPtr outer = l.getLoopStmtsFor(A)[0]; - // `outer` gets transformed into the outer loop after splitting. - LoopNest::splitWithTail(outer, 5, &inner, &tail); - - using namespace analysis; - - MemDependencyChecker analyzer; - l.root_stmt()->accept(&analyzer); - - ASSERT_EQ( - HazardKind::WriteAfterWrite, getPotentialHazards(analyzer, outer, tail)); -} - -TEST(BoundsInference, HasConflictingOverlapSameBufferWithPartialOverlap) { - // Input IR: - // for (const auto j : c10::irange(10, 100)) { - // A[j] = 10 * j; - // } - // for (const auto k : c10::irange(10, 100)) { - // A[k-1] = 20 * k; - // } - BufHandle a_buf("A", {200}, kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto forJ = For::make(j, 10, 100, Store::make(a_buf, {j}, Mul::make(10, j))); - auto forK = - For::make(k, 10, 100, Store::make(a_buf, {k - 1}, Mul::make(20, k))); - auto par = Block::make({forJ, forK}); - - tensorexpr::analysis::MemDependencyChecker analyzer; - par->accept(&analyzer); - ASSERT_TRUE(hasConflictingOverlap(analyzer, forJ, forK)); - ASSERT_TRUE(hasConflictingOverlap(analyzer, forK, forJ)); -} - -TEST(BoundsInference, HasConflictingOverlapSameBufferWithFullOverlap) { - // Input IR: - // for (const auto j : c10::irange(10, 100)) { - // A[j] = 10 * j; - // } - // for (const auto k : c10::irange(10, 100)) { - // A[k] = 20 * k; - // } - BufHandle a_buf("A", {200}, kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto forJ = For::make(j, 10, 100, Store::make(a_buf, {j}, Mul::make(10, j))); - auto forK = For::make(k, 10, 100, Store::make(a_buf, {k}, Mul::make(20, k))); - auto par = Block::make({forJ, forK}); - - tensorexpr::analysis::MemDependencyChecker analyzer; - par->accept(&analyzer); - ASSERT_TRUE(hasConflictingOverlap(analyzer, forJ, forK)); - ASSERT_TRUE(hasConflictingOverlap(analyzer, forK, forJ)); -} - -TEST(BoundsInference, HasConflictingOverlapSameBufferWithFullOverlapRAW) { - // Input IR: - // for (const auto j : c10::irange(10, 100)) { - // A[j] = 10 * j; - // } - // for (const auto k : c10::irange(10, 100)) { - // B[k] = A[k]; - // } - BufHandle a_buf("A", {200}, kInt); - BufHandle b_buf("B", {200}, kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto forJ = For::make(j, 10, 100, Store::make(a_buf, {j}, Mul::make(10, j))); - auto forK = - For::make(k, 10, 100, Store::make(b_buf, {k}, Load::make(a_buf, {k}))); - auto par =
Block::make({forJ, forK}); - - tensorexpr::analysis::MemDependencyChecker analyzer; - par->accept(&analyzer); - ASSERT_TRUE(hasConflictingOverlap(analyzer, forJ, forK)); - ASSERT_TRUE(hasConflictingOverlap(analyzer, forK, forJ)); -} - -TEST(BoundsInference, HasConflictingOverlapSameBufferNotOverlapping) { - // Input IR: - // for (const auto j : c10::irange(10, 100)) { - // A[j] = 10 * j; - // } - // for (const auto k : c10::irange(10, 100)) { - // A[k+100] = 20 * k; - // } - BufHandle a_buf("A", {200}, kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto forJ = For::make(j, 10, 100, Store::make(a_buf, {j}, Mul::make(10, j))); - auto forK = - For::make(k, 10, 100, Store::make(a_buf, {k + 100}, Mul::make(20, k))); - auto par = Block::make({forJ, forK}); - - tensorexpr::analysis::MemDependencyChecker analyzer; - par->accept(&analyzer); - ASSERT_FALSE(hasConflictingOverlap(analyzer, forJ, forK)); - ASSERT_FALSE(hasConflictingOverlap(analyzer, forK, forJ)); -} - -TEST(BoundsInference, HasConflictingOverlap2DBufferWithOverlap) { - // Input IR: - // for (const auto i : c10::irange(20)) { - // for (const auto j : c10::irange(100)) { - // A[i,j] = i * j * 500; - // } - // } - // for (const auto m : c10::irange(20)) { - // for (const auto n : c10::irange(50)) { - // A[m+1,n] = m + n * 100; - // } - // } - BufHandle a_buf("A", {20, 100}, kInt); - BufHandle b_buf("B", {20, 50}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle m("m", kInt); - VarHandle n("n", kInt); - auto storeA1 = Store::make(a_buf, {i, j}, Mul::make(Mul::make(i, j), 500)); - auto forJ = For::make(j, 0, 100, storeA1); - auto forI = For::make(i, 0, 20, forJ); - auto storeA2 = - Store::make(a_buf, {m + 1, n}, Add::make(m, Mul::make(n, 100))); - auto forN = For::make(n, 0, 50, storeA2); - auto forM = For::make(m, 0, 20, forN); - auto par = Block::make({forI, forM}); - - tensorexpr::analysis::MemDependencyChecker analyzer; - par->accept(&analyzer); - ASSERT_TRUE(hasConflictingOverlap(analyzer, forI, forM)); - ASSERT_TRUE(hasConflictingOverlap(analyzer, forM, forI)); - ASSERT_TRUE(hasConflictingOverlap(analyzer, forJ, forN)); - ASSERT_TRUE(hasConflictingOverlap(analyzer, forN, forJ)); - ASSERT_TRUE(hasConflictingOverlap(analyzer, storeA1, storeA2)); - ASSERT_TRUE(hasConflictingOverlap(analyzer, storeA2, storeA1)); - ASSERT_TRUE(hasConflictingOverlap(analyzer, forJ, storeA2)); - ASSERT_TRUE(hasConflictingOverlap(analyzer, storeA1, forM)); -} - -TEST(BoundsInference, HasConflictingOverlap2DBufferWithNoOverlap) { - // Input IR: - // for (const auto i : c10::irange(20)) { - // for (const auto j : c10::irange(100)) { - // A[i,j] = i * j * 500; - // } - // } - // for (const auto m : c10::irange(20)) { - // for (const auto n : c10::irange(50)) { - // A[m+20,n+100] = m + n * 100; - // } - // } - BufHandle a_buf("A", {20, 100}, kInt); - BufHandle b_buf("B", {20, 50}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle m("m", kInt); - VarHandle n("n", kInt); - auto storeA1 = Store::make(a_buf, {i, j}, Mul::make(Mul::make(i, j), 500)); - auto forJ = For::make(j, 0, 100, storeA1); - auto forI = For::make(i, 0, 20, forJ); - auto storeA2 = - Store::make(a_buf, {m + 20, n + 100}, Add::make(m, Mul::make(n, 100))); - auto forN = For::make(n, 0, 50, storeA2); - auto forM = For::make(m, 0, 20, forN); - auto par = Block::make({forI, forM}); - - tensorexpr::analysis::MemDependencyChecker analyzer; - par->accept(&analyzer); - ASSERT_FALSE(hasConflictingOverlap(analyzer, forI, forM)); - 
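// Rows [0, 19] vs [20, 39] and columns [0, 99] vs [100, 149] are disjoint, so neither ordering conflicts. -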
ASSERT_FALSE(hasConflictingOverlap(analyzer, forM, forI)); - ASSERT_FALSE(hasConflictingOverlap(analyzer, forJ, forN)); - ASSERT_FALSE(hasConflictingOverlap(analyzer, forN, forJ)); - ASSERT_FALSE(hasConflictingOverlap(analyzer, storeA1, storeA2)); - ASSERT_FALSE(hasConflictingOverlap(analyzer, storeA2, storeA1)); - ASSERT_FALSE(hasConflictingOverlap(analyzer, forJ, storeA2)); - ASSERT_FALSE(hasConflictingOverlap(analyzer, storeA1, forM)); -} - -TEST(BoundsInference, HasConflictingOverlapDifferentBuffers) { - // Input IR: - // for (const auto i : c10::irange(20)) { - // for (const auto j : c10::irange(100)) { - // A[i,j] = i * j * 500; - // } - // } - // for (const auto m : c10::irange(20)) { - // for (const auto n : c10::irange(50)) { - // B[m,n] = m + n * 100; - // } - // } - BufHandle a_buf("A", {20, 100}, kInt); - BufHandle b_buf("B", {20, 50}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle m("m", kInt); - VarHandle n("n", kInt); - auto storeA1 = Store::make(a_buf, {i, j}, Mul::make(Mul::make(i, j), 500)); - auto forJ = For::make(j, 0, 100, storeA1); - auto forI = For::make(i, 0, 20, forJ); - auto storeA2 = Store::make(b_buf, {m, n}, Add::make(m, Mul::make(n, 100))); - auto forN = For::make(n, 0, 50, storeA2); - auto forM = For::make(m, 0, 20, forN); - auto par = Block::make({forI, forM}); - - tensorexpr::analysis::MemDependencyChecker analyzer; - par->accept(&analyzer); - ASSERT_FALSE(hasConflictingOverlap(analyzer, forI, forM)); - ASSERT_FALSE(hasConflictingOverlap(analyzer, forM, forI)); - ASSERT_FALSE(hasConflictingOverlap(analyzer, forJ, forN)); - ASSERT_FALSE(hasConflictingOverlap(analyzer, forN, forJ)); - ASSERT_FALSE(hasConflictingOverlap(analyzer, storeA1, storeA2)); - ASSERT_FALSE(hasConflictingOverlap(analyzer, storeA2, storeA1)); - ASSERT_FALSE(hasConflictingOverlap(analyzer, forJ, storeA2)); - ASSERT_FALSE(hasConflictingOverlap(analyzer, storeA1, forM)); -} - -TEST(BoundsInference, HasConflictingOverlapDueToRAWDependence) { - // Input IR: - // for (const auto j : c10::irange(100)) { - // A[j] = 10 * j; - // } - // for (const auto k : c10::irange(100)) { - // B[k] = 20 * A[99-k]; - // } - BufHandle a_buf("A", {100}, kInt); - BufHandle b_buf("B", {100}, kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j))); - auto forK = For::make( - k, - 0, - 100, - Store::make( - b_buf, {k}, Mul::make(20, Load::make(a_buf, {ExprHandle(99) - k})))); - auto par = Block::make({forJ, forK}); - - tensorexpr::analysis::MemDependencyChecker analyzer; - par->accept(&analyzer); - ASSERT_TRUE(hasConflictingOverlap(analyzer, forJ, forK)); - ASSERT_TRUE(hasConflictingOverlap(analyzer, forK, forJ)); -} - -TEST(BoundsInference, HasConflictingOverlapDueToWARDependence) { - // Input IR: - // for (const auto k : c10::irange(100)) { - // B[k] = 20 * A[99-k]; - // } - // for (const auto j : c10::irange(100)) { - // A[j] = 10 * j; - // } - BufHandle a_buf("A", {100}, kInt); - BufHandle b_buf("B", {100}, kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto forK = For::make( - k, - 0, - 100, - Store::make( - b_buf, {k}, Mul::make(20, Load::make(a_buf, {ExprHandle(99) - k})))); - auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j))); - auto par = Block::make({forK, forJ}); - - tensorexpr::analysis::MemDependencyChecker analyzer; - par->accept(&analyzer); - ASSERT_TRUE(hasConflictingOverlap(analyzer, forJ, forK)); - ASSERT_TRUE(hasConflictingOverlap(analyzer, forK, forJ)); 
-} - -TEST(BoundsInference, HasConflictingOverlapWithLoads) { - // Input IR: - // for (const auto k : c10::irange(10, 100)) { - // B[k] = 20 * A[99-k]; - // } - // for (const auto j : c10::irange(10, 100)) { - // C[j] = 10 * A[j]; - // } - BufHandle a_buf("A", {100}, kInt); - BufHandle b_buf("B", {100}, kInt); - BufHandle c_buf("C", {100}, kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto forK = For::make( - k, - 10, - 100, - Store::make( - b_buf, {k}, Mul::make(20, Load::make(a_buf, {ExprHandle(99) - k})))); - auto forJ = For::make( - j, - 10, - 100, - Store::make(c_buf, {j}, Mul::make(10, Load::make(a_buf, {j})))); - auto par = Block::make({forK, forJ}); - - tensorexpr::analysis::MemDependencyChecker analyzer; - par->accept(&analyzer); - ASSERT_FALSE(hasConflictingOverlap(analyzer, forJ, forK)); - ASSERT_FALSE(hasConflictingOverlap(analyzer, forK, forJ)); -} - -TEST(BoundsInference, IsOverlapping) { - // Input IR: - // for (const auto i : c10::irange(100)) { - // A[i] = i * 10; // storeA1 - // B[i] = A[99-i] * 20; // loadA1 - // C[i] = A[i + 100] * 10; // loadA2 - // A[i + 50] = i * 50; // storeA2 - // A[i + 150] = i * 150; // storeA3 - // } - BufHandle a_buf("A", {300}, kInt); - BufHandle b_buf("B", {100}, kInt); - BufHandle c_buf("C", {100}, kInt); - VarHandle i("i", kInt); - auto storeA1 = Store::make(a_buf, {i}, i * 10); - auto loadA1 = Load::make(a_buf, {ExprHandle(99) - i}); - auto storeB = Store::make(b_buf, {i}, Mul::make(loadA1, 20)); - auto loadA2 = Load::make(a_buf, {i + 100}); - auto storeC = Store::make(c_buf, {i}, Mul::make(loadA2, 10)); - auto storeA2 = Store::make(a_buf, {i + 50}, i * 50); - auto storeA3 = Store::make(a_buf, {i + 150}, i * 150); - auto forI = For::make( - i, 0, 100, Block::make({storeA1, storeB, storeC, storeA2, storeA3})); - tensorexpr::analysis::MemDependencyChecker analyzer; - forI->accept(&analyzer); - ASSERT_TRUE(isOverlapping(analyzer, storeA1, to(loadA1.node()))); - ASSERT_FALSE(isOverlapping(analyzer, storeA1, to(loadA2.node()))); - ASSERT_TRUE(isOverlapping(analyzer, storeA1, storeA2)); - ASSERT_FALSE(isOverlapping(analyzer, storeA1, storeA3)); -} - -} // namespace jit -} // namespace torch diff --git a/test/cpp/tensorexpr/test_conv.cpp b/test/cpp/tensorexpr/test_conv.cpp deleted file mode 100644 index e72303873a6c..000000000000 --- a/test/cpp/tensorexpr/test_conv.cpp +++ /dev/null @@ -1,234 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include - -namespace torch { -namespace jit { - -namespace te = torch::jit::tensorexpr; -namespace F = torch::nn::functional; - -#ifdef TORCH_ENABLE_LLVM - -// Generate test data with few bits of precision, to minimize error -// accumulation from floating-point reordering. 
-static at::Tensor genTestData(c10::IntArrayRef args) { - return at::trunc(at::randn(args) * 256.0f) / 256.0f; -} - -TEST(Conv, DepthwiseConv2D) { - constexpr int N = 1, C = 72, H = 56, W = 56; - constexpr int K = 72, R = 3, S = 3; - constexpr int kPad = 1, kStride = 2, kGroups = C; - constexpr int CperG = C / kGroups; - - te::BufHandle input("input", {N, C, H, W}, te::kFloat); - te::BufHandle weight("weight", {K, CperG, R, S}, te::kFloat); - te::BufHandle bias("bias", {K}, te::kFloat); - te::Tensor output = - te::conv2d_depthwise(input, weight, bias, kStride, kPad, kGroups); - - te::LoopNest loop({output}); - loop.simplify(); - loop.prepareForCodegen(); - te::LLVMCodeGen cg(loop.root_stmt(), {input, weight, bias, output}); - - auto it = genTestData({N, C, H, W}); - auto wt = genTestData({K, CperG, R, S}); - auto bt = genTestData({K}); - auto ref = at::conv2d(it, wt, bt, kStride, kPad, /*dilation=*/1, kGroups); - auto ot = at::zeros_like(ref); - cg.call( - {it.data_ptr(), - wt.data_ptr(), - bt.data_ptr(), - ot.data_ptr()}); - - ASSERT_TRUE(at::allclose(ref, ot)); -} - -TEST(Conv, DepthwiseConv2DNoBias) { - constexpr int N = 1, C = 72, H = 56, W = 56; - constexpr int K = 72, R = 3, S = 3; - constexpr int kPad = 1, kStride = 2, kGroups = C; - constexpr int CperG = C / kGroups; - - te::BufHandle input("input", {N, C, H, W}, te::kFloat); - te::BufHandle weight("weight", {K, CperG, R, S}, te::kFloat); - te::Tensor output = - te::conv2d_depthwise(input, weight, kStride, kPad, kGroups); - - te::LoopNest loop({output}); - loop.simplify(); - loop.prepareForCodegen(); - te::LLVMCodeGen cg(loop.root_stmt(), {input, weight, output}); - - auto it = genTestData({N, C, H, W}); - auto wt = genTestData({K, CperG, R, S}); - auto ref = - at::conv2d(it, wt, at::Tensor(), kStride, kPad, /*dilation=*/1, kGroups); - auto ot = at::zeros_like(ref); - cg.call({it.data_ptr(), wt.data_ptr(), ot.data_ptr()}); - - ASSERT_TRUE(at::allclose(ref, ot)); -} - -TEST(Conv, DepthwiseConv2DDynamicShapes) { - te::VarHandle N_var("N", te::kInt); - te::VarHandle C_var("C", te::kInt); - te::VarHandle H_var("H", te::kInt); - te::VarHandle W_var("W", te::kInt); - te::VarHandle K_var("K", te::kInt); - te::VarHandle CperG_var("CperG", te::kInt); - te::VarHandle R_var("R", te::kInt); - te::VarHandle S_var("S", te::kInt); - te::VarHandle kPad_var("kPad", te::kInt); - te::VarHandle kStride_var("kStride", te::kInt); - te::VarHandle kGroups_var("kGroups", te::kInt); - - te::BufHandle input("input", {N_var, C_var, H_var, W_var}, te::kFloat); - te::BufHandle weight("weight", {K_var, CperG_var, R_var, S_var}, te::kFloat); - te::Tensor output = te::conv2d_depthwise( - input, - weight, - N_var, - C_var, - H_var, - W_var, - K_var, - CperG_var, - R_var, - S_var, - kStride_var, - kPad_var, - kGroups_var); - - te::LoopNest loop({output}); - loop.simplify(); - loop.prepareForCodegen(); - std::vector buffer_args = { - input, - weight, - N_var, - C_var, - H_var, - W_var, - K_var, - CperG_var, - R_var, - S_var, - kPad_var, - kStride_var, - kGroups_var, - output}; - te::LLVMCodeGen cg(loop.root_stmt(), buffer_args); - - constexpr int N = 1, C = 72, H = 56, W = 56; - constexpr int K = 72, R = 3, S = 3; - constexpr int kPad = 1, kStride = 2, kGroups = C; - constexpr int CperG = C / kGroups; - - auto it = genTestData({N, C, H, W}); - auto wt = genTestData({K, CperG, R, S}); - auto ref = - at::conv2d(it, wt, at::Tensor(), kStride, kPad, /*dilation=*/1, kGroups); - auto ot = at::zeros_like(ref); - std::vector call_args = { - it.data_ptr(), - wt.data_ptr(), - 
N, - C, - H, - W, - K, - CperG, - R, - S, - kPad, - kStride, - kGroups, - ot.data_ptr()}; - cg.call(call_args); - - ASSERT_TRUE(at::allclose(ref, ot)); -} - -#endif - -TEST(Conv, Conv2D) { - // Input dimensions. - constexpr int N = 1; - constexpr int C = 3; - constexpr int H = 11; - constexpr int W = 11; - - // Filter dimensions. - constexpr int K = 8; - constexpr int R = 3; - constexpr int S = 3; - - // Output dims. - constexpr int OH = H - R + 1; - constexpr int OW = W - S + 1; - - // Compute reference result. - at::Tensor input = torch::randn({N, C, H, W}); - at::Tensor filter = torch::randn({K, C, R, S}); - at::Tensor ref = F::conv2d(input, filter); - - // Double check the output size is as expected. - ASSERT_EQ(ref.size(0), N); - ASSERT_EQ(ref.size(1), K); - ASSERT_EQ(ref.size(2), OH); - ASSERT_EQ(ref.size(3), OW); - - te::BufHandle inputB("input", {N, C, H, W}, te::kFloat); - te::BufHandle filterB("filter", {K, C, R, S}, te::kFloat); - - te::Tensor conv = te::Reduce( - "conv", - {N, K, OH, OW}, - te::Sum(), - // FIXME: We have to use a `std::vector` parameter here and then unpack - // it, because we don't have an overload allowing for an arbitrary number - // of ExprHandle/VarHandle parameters. - [&](const std::vector& v) { - auto const& n = v[0]; - auto const& k = v[1]; - auto const& oh = v[2]; - auto const& ow = v[3]; - auto const& c = v[4]; - auto const& r = v[5]; - auto const& s = v[6]; - // FIXME: We have to use `call` and construct a `std::vector` here - // because the `operator()` overload is only specialized for a small - // number of arguments. - return inputB.load(n, c, oh + r, ow + s) * filterB.load(k, c, r, s); - }, - // FIXME: If you forget one of the reduction dims, you get a segfault. - // Could that be caught by a verifier? - {C, R, S}); - - // FIXME: It'd be nice to have a single header that pulls in things like - // LoopNest, IRSimplifier, etc. 
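- // Lower the reduction: flatten buffer indices, simplify, then hand the stmt to the evaluator.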
- te::LoopNest loop({conv}); - loop.prepareForCodegen(); - te::StmtPtr s = loop.root_stmt(); - s = te::IRSimplifier::simplify(s); - - at::Tensor result = at::empty_like(ref); - te::SimpleIREvaluator cg(s, {inputB, filterB, conv}); - cg.call( - {input.data_ptr(), - filter.data_ptr(), - result.data_ptr()}); - - ASSERT_TRUE(at::allclose(ref, result, 1e-3, 1e-3)); -} - -} // namespace jit -} // namespace torch diff --git a/test/cpp/tensorexpr/test_cpp_codegen.cpp b/test/cpp/tensorexpr/test_cpp_codegen.cpp deleted file mode 100644 index ed7679053637..000000000000 --- a/test/cpp/tensorexpr/test_cpp_codegen.cpp +++ /dev/null @@ -1,259 +0,0 @@ -#include - -#include "test/cpp/tensorexpr/test_base.h" - -#include -#include -#include -#include -#include -#include - -namespace torch { -namespace jit { - -using namespace torch::jit::tensorexpr; - -#define STR_CHECK(node, expected) \ - std::stringstream ss; \ - CppPrinter printer(&ss); \ - printer.visit(node); \ - ASSERT_EQ(ss.str(), expected) - -#define FILE_CHECK(node, pattern) \ - std::stringstream ss; \ - CppPrinter printer(&ss); \ - printer.visit(node); \ - torch::jit::testing::FileCheck().run(pattern, ss.str()) - -TEST(CppPrinter, IntImm) { - auto i = alloc(10); - STR_CHECK(i, "10"); -} - -TEST(CppPrinter, FloatImm) { - auto f = alloc(10); - STR_CHECK(f, "10.f"); -} - -TEST(CppPrinter, FloatImm1) { - auto f = alloc(10); - STR_CHECK(f, "10.f"); -} - -TEST(CppPrinter, DoubleImm) { - auto d = alloc(10); - STR_CHECK(d, "10.0"); -} - -TEST(CppPrinter, DoubleImm1) { - auto d = alloc(10.1); - STR_CHECK(d, "10.1"); -} - -TEST(CppPrinter, HalfImm) { - auto h = alloc(10); - STR_CHECK(h, "10"); -} - -TEST(CppPrinter, Add) { - auto add = alloc(alloc(1), alloc(2)); - STR_CHECK(add, "1 + 2"); -} - -TEST(CppPrinter, AddExpr1) { - auto add = alloc( - alloc(alloc(0), alloc(1)), - alloc(alloc(2), alloc(3))); - STR_CHECK(add, "(0 + 1) + (2 - 3)"); -} - -TEST(CppPrinter, AddExpr2) { - auto add = alloc( - alloc(alloc(0), alloc(1)), - alloc(alloc(2), alloc(3))); - STR_CHECK(add, "0 * 1 + (2 - 3)"); -} - -TEST(CppPrinter, AddExpr3) { - auto add = alloc( - alloc(alloc(0), alloc(1)), - alloc
(alloc(2), alloc(3))); - STR_CHECK(add, "(0 + 1) + 2 / 3"); -} - -TEST(CppPrinter, Mod) { - auto mod = alloc(alloc(1), alloc(2)); - STR_CHECK(mod, "1 % 2"); -} - -TEST(CppPrinter, ModFloat) { - auto mod = alloc(alloc(1), alloc(2)); - STR_CHECK(mod, "std::fmod(1.f, 2.f)"); -} - -TEST(CppPrinter, Max) { - auto max = alloc(alloc(1), alloc(2), false); - STR_CHECK(max, "std::max(1, 2)"); -} - -TEST(CppPrinter, MaxFloat) { - auto max = alloc(alloc(1), alloc(2), false); - STR_CHECK(max, "std::max(1.f, 2.f)"); -} - -TEST(CppPrinter, MaxHalf) { - auto max = alloc(alloc(1), alloc(2), false); - STR_CHECK(max, "(1 < 2) ? 2 : 1"); -} - -TEST(CppPrinter, And) { - auto v = alloc(alloc(1), alloc(2)); - STR_CHECK(v, "1 & 2"); -} - -TEST(CppPrinter, CompareSelect) { - auto cs = alloc( - alloc(1), - alloc(2), - alloc(1), - alloc(2), - CompareSelectOperation::kLE); - STR_CHECK(cs, "((1 <= 2) ? 1.f : 2.f)"); -} - -TEST(CppPrinter, IfThenElse) { - auto cond = alloc(alloc(1), alloc(2)); - auto true_value = alloc(alloc(0), alloc(1)); - auto false_value = alloc(alloc(2), alloc(3)); - auto v = alloc(cond, true_value, false_value); - STR_CHECK(v, "((1 + 2) ? 0 - 1 : 2 * 3)"); -} - -TEST(CppPrinter, AllocateFree) { - BufHandle buf("x", {2, 3}, kInt); - AllocatePtr alloc = Allocate::make(buf); - FreePtr free = Free::make(buf); - BlockPtr block = Block::make({alloc, free}); - - const std::string pattern = R"( - # CHECK: { - # CHECK: int* x = static_cast(malloc(24)); - # CHECK: free(x); - # CHECK: } - )"; - FILE_CHECK(block, pattern); -} - -TEST(CppPrinter, LoadStore) { - BufHandle a("A", {2, 3}, kInt); - BufHandle b("B", {3, 4}, kInt); - auto store = b.store({2, 2}, a.load(1, 1)); - STR_CHECK( - store, "B[(0 + 2 * (1 * 4)) + 2 * 1] = A[(0 + 1 * (1 * 3)) + 1 * 1];\n"); -} - -TEST(CppPrinter, Var) { - auto var = alloc("x", kInt); - STR_CHECK(var, "x"); -} - -TEST(CppPrinter, Cast) { - auto cast = alloc(kFloat, alloc(1)); - STR_CHECK(cast, "static_cast(1)"); -} - -TEST(CppPrinter, BitCast) { - auto cast = alloc(kInt, alloc(20)); - STR_CHECK(cast, "std::bitcast(20.f)"); -} - -TEST(CppPrinter, Let) { - auto var = alloc("x", kFloat); - auto val = alloc(2); - auto let = alloc(var, val); - STR_CHECK(let, "float x = 2.f;\n"); -} - -TEST(CppPrinter, For) { - constexpr int N = 1024; - BufHandle a("A", {N}, kInt); - BufHandle b("B", {N}, kInt); - BufHandle c("C", {N}, kInt); - VarHandle i("i", kInt); - auto f = For::make(i, 0, N, c.store({i}, Add::make(a.load(i), b.load(i)))); - const std::string pattern = R"( - # CHECK: for (int i = 0; i < 1024; i++) { - # CHECK: C[i] = (A[i]) + (B[i]); - # CHECK: } - )"; - FILE_CHECK(f, pattern); -} - -TEST(CppPrinter, Cond) { - BufHandle x("X", {1}, kInt); - auto cmp = CompareSelect::make(x.load(0), 10, CompareSelectOperation::kLT); - auto cond = - Cond::make(cmp, x.store({0}, x.load(0) + 1), x.store({0}, x.load(0) - 1)); - const std::string pattern = R"( - # CHECK: if (((X[0] < 10) ? 
1 : 0)) { - # CHECK: X[0] = (X[0]) + 1; - # CHECK: } else { - # CHECK: X[0] = (X[0]) - 1; - # CHECK: } - )"; - FILE_CHECK(cond, pattern); -} - -TEST(CppPrinter, Intrinsics) { - const std::unordered_set> unsupported_ops{ - kRand, kSigmoid}; - for (const auto i : c10::irange(static_cast(kMaxIntrinsicsOp))) { - IntrinsicsOp op = static_cast(i); - if (unsupported_ops.count(op)) { - continue; - } - - if (Intrinsics::OpArgCount(op) == 1) { - auto v = alloc(op, alloc(2.0f)); - STR_CHECK(v, "std::" + v->func_name() + "(2.f)"); - } else { - auto v = - alloc(op, alloc(1.0f), alloc(2.0f)); - STR_CHECK(v, "std::" + v->func_name() + "(1.f, 2.f)"); - } - } -} - -TEST(CppPrinter, ExternalCall) { - std::vector dims{alloc(2), alloc(2)}; - auto output = alloc("out", dims, kFloat); - auto buf_arg1 = alloc("a", dims, kFloat); - auto buf_arg2 = alloc("b", dims, kFloat); - auto scalar_arg = alloc(alloc(1), alloc(2)); - std::vector buf_args{buf_arg1, buf_arg2}; - std::vector scalar_args{scalar_arg}; - auto call = - alloc(output, "nnc_aten_matmul", buf_args, scalar_args); - const std::string pattern = R"( - # CHECK: { - # CHECK: void* buf_ptrs[]{out, a, b}; - # CHECK: int64_t buf_ranks[]{2, 2, 2}; - # CHECK: int64_t buf_dims[]{2, 2, 2, 2, 2, 2}; - # CHECK: int8_t buf_dtypes[]{6, 6, 6}; - # CHECK: int64_t extra_args[]{1 + 2}; - # CHECK: nnc_aten_matmul( - # CHECK: 3, - # CHECK: buf_ptrs, - # CHECK: buf_ranks, - # CHECK: buf_dims, - # CHECK: buf_dtypes, - # CHECK: 1, - # CHECK: extra_args); - # CHECK: } - )"; - FILE_CHECK(call, pattern); -} - -} // namespace jit -} // namespace torch diff --git a/test/cpp/tensorexpr/test_cuda.cpp b/test/cpp/tensorexpr/test_cuda.cpp deleted file mode 100644 index 2e1e84e758db..000000000000 --- a/test/cpp/tensorexpr/test_cuda.cpp +++ /dev/null @@ -1,2344 +0,0 @@ -#ifdef USE_CUDA - -#include -#include -#include - -#include - -#include - -#include -#include -#include -#include -#include -#include - -#include - -#include -#include -#include - -namespace torch { -namespace jit { -using namespace torch::jit::tensorexpr; -using namespace torch::jit::tensorexpr; - -template -static void testCudaTestVectorAdd01_impl() { - const int num_iter = 3; - const int block_count = 16; - const int block_size = 128; - Dtype dtype = ToDtype(); - BufHandle a_buf("a", {num_iter, block_count, block_size}, dtype); - BufHandle b_buf("b", {num_iter, block_count, block_size}, dtype); - Tensor c = Compute( - "c", - { - num_iter, - block_count, - block_size, - }, - [&](const VarHandle& n, const VarHandle& b_id, const VarHandle& t_id) { - return a_buf.load(n, b_id, t_id) + b_buf.load(n, b_id, t_id); - }); - LoopNest l({c}); - std::vector loops = l.getLoopStmtsFor(c); - loops[1]->set_gpu_block_index(0); - loops[2]->set_gpu_thread_index(0); - l.prepareForCodegen(); - StmtPtr stmt = l.root_stmt(); - CudaCodeGen cuda_cg(stmt, c, a_buf, b_buf); - const int N = block_count * block_size * num_iter; - PaddedBuffer a_v(N); - PaddedBuffer b_v(N); - PaddedBuffer c_v(N); - PaddedBuffer c_ref(N); - - for (const auto i : c10::irange(N)) { - a_v(i) = ctype(i); - b_v(i) = ctype(i * 3 + 7); - c_ref(i) = a_v(i) + b_v(i); - } - - // TODO: move gpu support into PaddedBuffer - ctype* a_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&a_dev, N * sizeof(ctype))); - ctype* b_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&b_dev, N * sizeof(ctype))); - ctype* c_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&c_dev, N * sizeof(ctype))); - C10_CUDA_CHECK( - cudaMemcpy(a_dev, a_v.data(), N * sizeof(ctype), cudaMemcpyHostToDevice)); - C10_CUDA_CHECK( - 
cudaMemcpy(b_dev, b_v.data(), N * sizeof(ctype), cudaMemcpyHostToDevice)); - C10_CUDA_CHECK( - cudaMemcpy(c_dev, c_v.data(), N * sizeof(ctype), cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - cuda_cg(c_dev, a_dev, b_dev); - - C10_CUDA_CHECK(cudaDeviceSynchronize()); - C10_CUDA_CHECK( - cudaMemcpy(c_v.data(), c_dev, N * sizeof(ctype), cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - ExpectAllNear(c_v, c_ref, 1e-5); - - C10_CUDA_CHECK(cudaFree(a_dev)); - C10_CUDA_CHECK(cudaFree(b_dev)); - C10_CUDA_CHECK(cudaFree(c_dev)); -} - -float sigmoid(float x) { - return 1.0f / (1.0f + expf(-0.0f - x)); -} - -TEST(Cuda, Sigmoid_CUDA) { - const int num_iter = 3; - const int block_count = 16; - const int block_size = 128; - Dtype dtype = ToDtype(); - BufHandle a_buf("a", {num_iter, block_count, block_size}, dtype); - Tensor c = Compute( - "c", - { - num_iter, - block_count, - block_size, - }, - [&](const VarHandle& n, const VarHandle& b_id, const VarHandle& t_id) { - return sigmoid(sigmoid(a_buf.load(n, b_id, t_id))); - }); - LoopNest l({c}); - std::vector loops = l.getLoopStmtsFor(c); - loops[1]->set_gpu_block_index(0); - loops[2]->set_gpu_thread_index(0); - l.prepareForCodegen(); - StmtPtr stmt = l.root_stmt(); - CudaCodeGen cuda_cg(stmt, c, a_buf); - const int N = block_count * block_size * num_iter; - PaddedBuffer a_v(N); - PaddedBuffer c_v(N); - PaddedBuffer c_ref(N); - - for (const auto i : c10::irange(N)) { - a_v(i) = float(i); - c_ref(i) = sigmoid(sigmoid(a_v(i))); - } - - // TODO: move gpu support into PaddedBuffer - float* a_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&a_dev, N * sizeof(float))); - float* c_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&c_dev, N * sizeof(float))); - C10_CUDA_CHECK( - cudaMemcpy(a_dev, a_v.data(), N * sizeof(float), cudaMemcpyHostToDevice)); - C10_CUDA_CHECK( - cudaMemcpy(c_dev, c_v.data(), N * sizeof(float), cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - cuda_cg(c_dev, a_dev); - - C10_CUDA_CHECK(cudaDeviceSynchronize()); - C10_CUDA_CHECK( - cudaMemcpy(c_v.data(), c_dev, N * sizeof(float), cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - ExpectAllNear(c_v, c_ref, 1e-5); - - C10_CUDA_CHECK(cudaFree(a_dev)); - C10_CUDA_CHECK(cudaFree(c_dev)); -} - -TEST(Cuda, TestVectorAdd01_CUDA) { - // floating types. - testCudaTestVectorAdd01_impl(); - testCudaTestVectorAdd01_impl(); - testCudaTestVectorAdd01_impl(); - - // integer types. 
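- // (Each instantiation round-trips N values through device buffers of the given ctype.)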
- testCudaTestVectorAdd01_impl<int8_t>(); - testCudaTestVectorAdd01_impl<uint8_t>(); - testCudaTestVectorAdd01_impl<int16_t>(); - testCudaTestVectorAdd01_impl<int32_t>(); - testCudaTestVectorAdd01_impl<int64_t>(); -} - -static void testCudaTestVectorAdd02_impl(int64_t N, int64_t block_size) { - BufHandle a_buf("a", {N}, kFloat); - BufHandle b_buf("b", {N}, kFloat); - Tensor c = Compute("c", {N}, [&](const VarHandle& n) { - return a_buf.load(n) + b_buf.load(n); - }); - LoopNest l({c}); - ForPtr n_inner; - std::vector<ForPtr> loops = l.getLoopStmtsFor(c); - l.splitWithMask(loops[0], block_size, &n_inner); - loops[0]->set_gpu_block_index(0); - n_inner->set_gpu_thread_index(0); - l.prepareForCodegen(); - StmtPtr stmt = l.root_stmt(); - CudaCodeGen cuda_cg(stmt, c, a_buf, b_buf); - PaddedBuffer<float> a_v(N); - PaddedBuffer<float> b_v(N); - PaddedBuffer<float> c_v(N); - PaddedBuffer<float> c_ref(N); - - for (const auto i : c10::irange(N)) { - a_v(i) = i; - b_v(i) = i * 3 + 7; - c_ref(i) = a_v(i) + b_v(i); - } - - // TODO: move gpu support into PaddedBuffer - float* a_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&a_dev, N * sizeof(float))); - float* b_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&b_dev, N * sizeof(float))); - float* c_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&c_dev, N * sizeof(float))); - C10_CUDA_CHECK( - cudaMemcpy(a_dev, a_v.data(), N * sizeof(float), cudaMemcpyHostToDevice)); - C10_CUDA_CHECK( - cudaMemcpy(b_dev, b_v.data(), N * sizeof(float), cudaMemcpyHostToDevice)); - C10_CUDA_CHECK( - cudaMemcpy(c_dev, c_v.data(), N * sizeof(float), cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - cuda_cg(c_dev, a_dev, b_dev); - - C10_CUDA_CHECK(cudaDeviceSynchronize()); - C10_CUDA_CHECK( - cudaMemcpy(c_v.data(), c_dev, N * sizeof(float), cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - ExpectAllNear(c_v, c_ref, 1e-5); - - C10_CUDA_CHECK(cudaFree(a_dev)); - C10_CUDA_CHECK(cudaFree(b_dev)); - C10_CUDA_CHECK(cudaFree(c_dev)); -} - -TEST(Cuda, TestVectorAdd02_CUDA) { - testCudaTestVectorAdd02_impl(1024, 128); - testCudaTestVectorAdd02_impl(1030, 128); -} - -TEST(Cuda, HalfCast_CUDA) { - auto half = ToDtype<at::Half>(); - BufHandle a("a", {4}, half); - Tensor b = Compute("b", {4}, [&](const VarHandle& i) { - return Cast::make(kFloat, a.load(i)); - }); - - LoopNest l({b}); - l.prepareForCodegen(); - StmtPtr s = l.root_stmt(); - CudaCodeGen cg(s, {a, b}); - - std::vector<at::Half> aData(4, 2.0f); - std::vector<float> bData(4, 0.0f); - at::Half* aDev = nullptr; - float* bDev = nullptr; - auto aSize = aData.size() * sizeof(aData[0]); - auto bSize = bData.size() * sizeof(bData[0]); - - C10_CUDA_CHECK(cudaMalloc(&aDev, aSize)); - C10_CUDA_CHECK(cudaMalloc(&bDev, bSize)); - C10_CUDA_CHECK(cudaMemcpy(aDev, aData.data(), aSize, cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy(bDev, bData.data(), bSize, cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - cg.call({aDev, bDev}); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - C10_CUDA_CHECK(cudaMemcpy(aData.data(), aDev, aSize, cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaMemcpy(bData.data(), bDev, bSize, cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - assertAllEqual(bData, 2.0f); - - C10_CUDA_CHECK(cudaFree(aDev)); - C10_CUDA_CHECK(cudaFree(bDev)); -} - -TEST(Cuda, DynamicShape2D_CUDA) { - auto testWithSize = [](int32_t M, int32_t N) { - VarHandle m("m", kInt); - VarHandle n("n", kInt); - BufHandle a("a", {m, n}, kFloat); - BufHandle b("b", {m, n}, kFloat); - Tensor c = - Compute("c", {m, n}, [&](const VarHandle& i, const VarHandle& j) { - return a.load(i, j) +
b.load(i, j); - }); - LoopNest l({c}); - l.prepareForCodegen(); - StmtPtr s = l.root_stmt(); - CudaCodeGen cg(s, {a, b, c, m, n}); - - std::vector aData(M * N, 1.0f); - std::vector bData(M * N, 2.0f); - std::vector cData(M * N, 0.0f); - float* aDev = nullptr; - float* bDev = nullptr; - float* cDev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&aDev, aData.size() * sizeof(aData[0]))); - C10_CUDA_CHECK(cudaMalloc(&bDev, bData.size() * sizeof(bData[0]))); - C10_CUDA_CHECK(cudaMalloc(&cDev, cData.size() * sizeof(cData[0]))); - C10_CUDA_CHECK(cudaMemcpy( - aDev, - aData.data(), - aData.size() * sizeof(aData[0]), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - bDev, - bData.data(), - bData.size() * sizeof(bData[0]), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - cDev, - cData.data(), - cData.size() * sizeof(cData[0]), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - cg.call({aDev, bDev, cDev, M, N}); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - C10_CUDA_CHECK(cudaMemcpy( - cData.data(), - cDev, - cData.size() * sizeof(cData[0]), - cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - ExpectAllNear(cData, std::vector(M * N, 3.0f), 1e-7); - - C10_CUDA_CHECK(cudaFree(aDev)); - C10_CUDA_CHECK(cudaFree(bDev)); - C10_CUDA_CHECK(cudaFree(cDev)); - }; - testWithSize(32, 32); - testWithSize(1, 16); - testWithSize(27, 13); -} - -TEST(Cuda, TestRand01_CUDA) { - const int num_iter = 3; - const int block_count = 16; - const int block_size = 128; - Tensor c = Compute( - "c", - { - num_iter, - block_count, - block_size, - }, - [&](const VarHandle& n, const VarHandle& b_id, const VarHandle& t_id) { - return Intrinsics::make(IntrinsicsOp::kRand, kFloat); - }); - LoopNest l({c}); - std::vector loops = l.getLoopStmtsFor(c); - loops[1]->set_gpu_block_index(0); - loops[2]->set_gpu_thread_index(0); - l.prepareForCodegen(); - StmtPtr stmt = l.root_stmt(); - CudaCodeGen cuda_cg(stmt, c); - const int N = block_count * block_size * num_iter; - PaddedBuffer c_v(N); - - // TODO: move gpu support into PaddedBuffer - float* c_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&c_dev, N * sizeof(float))); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - cuda_cg(c_dev); - - C10_CUDA_CHECK(cudaDeviceSynchronize()); - C10_CUDA_CHECK( - cudaMemcpy(c_v.data(), c_dev, N * sizeof(float), cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - float sum1 = 0; - float sum2 = 0; - float sum3 = 0; - for (const auto i : c10::irange(N)) { - float v = c_v.data()[i]; - sum1 += v; - sum2 += v * v; - sum3 += v * v * v; - ASSERT_TRUE(v >= 0 && v < 1); - } - sum1 /= N; - sum2 /= N; - sum3 /= N; - float sum1_mean = 1.f / 2; - float sum2_mean = 1.f / 3; - float sum3_mean = 1.f / 4; - - ASSERT_NEAR(sum1, sum1_mean, 2e-2); - ASSERT_NEAR(sum2, sum2_mean, 2e-2); - ASSERT_NEAR(sum3, sum3_mean, 2e-2); - C10_CUDA_CHECK(cudaFree(c_dev)); -} - -TEST(Cuda, DynamicShapeSplit_CUDA) { - constexpr int64_t N = 4096; - VarHandle n("n", kLong); - BufHandle a("a", {n}, kFloat); - Tensor b = - Compute("b", {n}, [&](const VarHandle& i) { return a.load(i) * 2.0f; }); - LoopNest l({b}); - ForPtr inner; - std::vector loops = l.getLoopStmtsFor(b); - l.splitWithMask(loops[0], 1024, &inner); - loops[0]->set_gpu_block_index(0); - inner->set_gpu_thread_index(0); - StmtPtr s = l.root_stmt(); - CudaCodeGen cg(s, {a, b, n}); - - std::vector aData(N, 1.0f); - std::vector bData(N, 1.0f); - float* aDev = nullptr; - float* bDev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&aDev, aData.size() * sizeof(aData[0]))); - 
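// The TestRand01 check above relies on the raw moments of a uniform [0, 1)
// variable: E[X^k] = 1/(k+1), so E[X] = 1/2, E[X^2] = 1/3, E[X^3] = 1/4.
// A host-only sketch of the same acceptance test, assuming nothing beyond the
// C++ standard library (the 2e-2 tolerance mirrors the test above):
#include <cassert>
#include <cmath>
#include <random>

void checkUniformMoments() {
  std::mt19937 gen(0);
  std::uniform_real_distribution<float> dist(0.0f, 1.0f);
  const int N = 3 * 16 * 128;  // num_iter * block_count * block_size
  double sum1 = 0, sum2 = 0, sum3 = 0;
  for (int i = 0; i < N; ++i) {
    double v = dist(gen);
    assert(v >= 0 && v < 1);
    sum1 += v;
    sum2 += v * v;
    sum3 += v * v * v;
  }
  assert(std::abs(sum1 / N - 1.0 / 2) < 2e-2);  // E[X]
  assert(std::abs(sum2 / N - 1.0 / 3) < 2e-2);  // E[X^2]
  assert(std::abs(sum3 / N - 1.0 / 4) < 2e-2);  // E[X^3]
}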
C10_CUDA_CHECK(cudaMalloc(&bDev, bData.size() * sizeof(bData[0]))); - C10_CUDA_CHECK(cudaMemcpy( - aDev, - aData.data(), - aData.size() * sizeof(aData[0]), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - bDev, - bData.data(), - bData.size() * sizeof(aData[0]), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - cg.call({aDev, bDev, N}); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - C10_CUDA_CHECK(cudaMemcpy( - bData.data(), - bDev, - bData.size() * sizeof(aData[0]), - cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - ExpectAllNear(bData, std::vector(N, 2.0f), 1e-7); - - C10_CUDA_CHECK(cudaFree(aDev)); - C10_CUDA_CHECK(cudaFree(bDev)); -} - -TEST(Cuda, OneBlockOneThreadGlobalReduce1_CUDA) { - const static int N = 1024; - BufHandle data_buf("data", {N}, kFloat); - BufHandle output_buf("output", {1}, kFloat); - - // The test adds the following code for trivial reduction: - // for (const auto bidx : c10::irange(1)) { // blockIdx.x - // for (const auto tidx : c10::irange(1)) { // threadIdx.x - // output[0] = 0.f; - // for (const auto i1 : c10::irange(1024)) { - // output[0] = output[0] + data[i1]; - // } - // } - // } - - StorePtr init_store = output_buf.store({0}, 0.f); - VarHandle i1("i1", kInt); - ExprHandle load_data = Load::make(data_buf, {i1}); - ExprHandle load_output = Load::make(output_buf, {0}); - ExprHandle add_value = load_output + load_data; - StorePtr store_output = output_buf.store({0}, add_value); - ForPtr for_output = For::make(i1, 0, N, store_output); - StmtPtr reduce_block = Block::make({init_store, for_output}); - VarHandle thread_idx("tidx", kInt); - LoopOptions thread_idx_options; - thread_idx_options.set_gpu_thread_index(0); - ForPtr thread_idx_loop = - For::make(thread_idx, 0, 1, reduce_block, thread_idx_options); - VarHandle block_idx("bidx", kInt); - LoopOptions block_idx_options; - block_idx_options.set_gpu_block_index(0); - ForPtr block_idx_loop = - For::make(block_idx, 0, 1, thread_idx_loop, block_idx_options); - - CudaCodeGen cuda_cg(block_idx_loop, data_buf, output_buf); - PaddedBuffer data_v(N); - PaddedBuffer output_v(1, "output_v"); - PaddedBuffer output_ref(1, "output_ref"); - - output_ref(0) = 0; - for (const auto i : c10::irange(N)) { - data_v(i) = i; - output_ref(0) += data_v(i); - } - - float* data_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&data_dev, N * sizeof(float))); - C10_CUDA_CHECK(cudaMemcpy( - data_dev, data_v.data(), N * sizeof(float), cudaMemcpyHostToDevice)); - float* output_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&output_dev, 1 * sizeof(float))); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - cuda_cg(data_dev, output_dev); - - C10_CUDA_CHECK(cudaDeviceSynchronize()); - C10_CUDA_CHECK(cudaMemcpy( - output_v.data(), output_dev, 1 * sizeof(float), cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - ExpectAllNear(output_v, output_ref, 1e-5); - - C10_CUDA_CHECK(cudaFree(data_dev)); - C10_CUDA_CHECK(cudaFree(output_dev)); -} - -TEST(Cuda, OneBlockMultiThreadGlobalReduce1_CUDA) { - const static int N = 1024; - - // This test does the following reduction: - // clang-format off - // for b in 0..1 // block-idx - // for t in 0..1024: // thread-idx - // if t < 1: - // b[0] = 0 - // // implied sync_threads - // for t in 0..1024: // thread-idx - // b[0] = b[0] + a[t] // implied atomic - // clang-format on - - BufHandle a_buf("a", {N}, kFloat); - BufHandle b_buf("b", {1}, kFloat); - - StorePtr init_store = b_buf.store({0}, 0.f); - VarHandle t("t", kInt); - VarHandle b("b", 
kInt); - - // for t in 0..1024: // thread-idx - // if t < 1: - // b[0] = 0 - ExprHandle cond_t_lt_1 = - CompareSelect::make(t, 1, CompareSelectOperation::kLT); - CondPtr masked_init_b = Cond::make(cond_t_lt_1, init_store, nullptr); - LoopOptions thread_idx_options; - thread_idx_options.set_gpu_thread_index(0); - ForPtr for_init = For::make(t, 0, N, masked_init_b, thread_idx_options); - - // for t in 0..1024: // thread-idx - // b[0] = b[0] + a[t] // implied atomic - ExprHandle load_a = Load::make(a_buf, {t}); - ExprHandle load_b = Load::make(b_buf, {0}); - ExprHandle add_value = load_b + load_a; - StorePtr store_b = b_buf.store({0}, add_value); - ForPtr for_b = For::make(t, 0, N, store_b, thread_idx_options); - - StmtPtr reduce_block = Block::make({for_init, for_b}); - - VarHandle block_idx("bidx", kInt); - LoopOptions block_idx_options; - block_idx_options.set_gpu_block_index(0); - ForPtr block_idx_loop = - For::make(block_idx, 0, 1, reduce_block, block_idx_options); - - CudaCodeGen cuda_cg(block_idx_loop, a_buf, b_buf); - PaddedBuffer a_v(N); - PaddedBuffer b_v(1, "b_v"); - PaddedBuffer b_ref(1, "b_ref"); - - b_ref(0) = 0; - for (const auto i : c10::irange(N)) { - a_v(i) = i; - b_ref(0) += a_v(i); - } - - float* a_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&a_dev, N * sizeof(float))); - C10_CUDA_CHECK( - cudaMemcpy(a_dev, a_v.data(), N * sizeof(float), cudaMemcpyHostToDevice)); - float* b_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&b_dev, 1 * sizeof(float))); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - cuda_cg(a_dev, b_dev); - - C10_CUDA_CHECK(cudaDeviceSynchronize()); - C10_CUDA_CHECK( - cudaMemcpy(b_v.data(), b_dev, 1 * sizeof(float), cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - ExpectAllNear(b_v, b_ref, 1e-5); - - C10_CUDA_CHECK(cudaFree(a_dev)); - C10_CUDA_CHECK(cudaFree(b_dev)); -} - -TEST(Cuda, NoThreadIdxWrite_1_CUDA) { - // This test does the following reduction: - // - // for k in 0..1: // block-idx - // a[0] = 0 - // for n in 0..2: - // a[0] = a[0] + n - // for m in 0..1024: // thread-idx - // b[m] = m - // a[1] = 1 - // for l in 0..2: - // a[1] = a[1] + n - // - // note that the statements not covered by thread-idx are supposed to be - // covered by its own thread-idx - - const static int N = 1024; - BufHandle a_buf("a", {2}, kFloat); - BufHandle b_buf("b", {N}, kFloat); - - VarHandle k("k", kInt); - VarHandle l("l", kInt); - VarHandle m("m", kInt); - VarHandle n("n", kInt); - - // a[0] = 0 - // for n in 0..2: - // a[0] = a[0] + n - StorePtr store_a0_0 = a_buf.store({0}, 0.f); - ExprHandle load_a0 = Load::make(a_buf, {0}); - ExprHandle v1 = load_a0 + n; - StorePtr store_a0_v1 = a_buf.store({0}, v1); - ForPtr loop_a_0 = For::make(n, 0, 2, store_a0_v1); - - // for m in 0..1024: // thread-idx - // b[m] = m - StorePtr store_bm_m = b_buf.store({m}, m + 0.f); - LoopOptions thread_idx_options; - thread_idx_options.set_gpu_thread_index(0); - ForPtr loop_b_1 = For::make(m, 0, N, store_bm_m, thread_idx_options); - - // a[1] = 1 - // for l in 0..2: - // a[1] = a[1] + l - StorePtr store_a1_1 = a_buf.store({1}, 1.f); - ExprHandle load_a1 = a_buf.load(1); - ExprHandle v2 = load_a1 + l; - StorePtr store_a1_v2 = a_buf.store({1}, v2); - ForPtr loop_a_1 = For::make(l, 0, 2, store_a1_v2); - - StmtPtr reduce_block = - Block::make({store_a0_0, loop_a_0, loop_b_1, store_a1_1, loop_a_1}); - - VarHandle block_idx("bidx", kInt); - LoopOptions block_idx_options; - block_idx_options.set_gpu_block_index(0); - ForPtr block_idx_loop = - For::make(block_idx, 0, 1, reduce_block, 
block_idx_options); - - CudaCodeGen cuda_cg(block_idx_loop, a_buf, b_buf); - PaddedBuffer a_v(2); - PaddedBuffer b_v(N, "b_v"); - PaddedBuffer a_ref(2, "a_ref"); - PaddedBuffer b_ref(N, "b_ref"); - - a_ref(0) = 0; - for (const auto i : c10::irange(2)) { - a_ref(0) += i; - } - a_ref(1) = a_ref(0) + 1; - for (const auto i : c10::irange(N)) { - b_ref(i) = i; - } - - // TODO: add check of the generated code. - float* a_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&a_dev, 2 * sizeof(float))); - float* b_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&b_dev, N * sizeof(float))); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - cuda_cg(a_dev, b_dev); - - C10_CUDA_CHECK(cudaDeviceSynchronize()); - C10_CUDA_CHECK( - cudaMemcpy(a_v.data(), a_dev, 2 * sizeof(float), cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK( - cudaMemcpy(b_v.data(), b_dev, N * sizeof(float), cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - ExpectAllNear(a_v, a_ref, 1e-5); - ExpectAllNear(b_v, b_ref, 1e-5); - - C10_CUDA_CHECK(cudaFree(a_dev)); - C10_CUDA_CHECK(cudaFree(b_dev)); -} - -TEST(Cuda, SharedMemReduce_1_CUDA) { - // FIXME: this test is flaky in CI. - // This test does the following: - // for k in 0..1: // block-idx - // alloc(c, 64) - // for n in 0..64: // thread-idx - // c(n) = 0 - // for m in 0..128: - // for n in 0..64: // thread_idx - // c(n) = c(n) + a(k, m, n) - // b(k) = 0 - // for n in 0..64: // thread_idx - // b(k) = b(k) + c(n) - // free(c) - - const int M = 128; - const int N = 64; - const int kTotalSize = M * N; - LoopOptions thread_idx_opt; - thread_idx_opt.set_gpu_thread_index(0); - LoopOptions block_idx_opt; - block_idx_opt.set_gpu_block_index(0); - - BufHandle a("a", {1, M, N}, kFloat); - BufHandle b("b", {1}, kFloat); - VarHandle k("k", kInt); - VarHandle m("m", kInt); - VarHandle n("n", kInt); - - std::vector block; - std::vector dims; - dims.push_back(ExprHandle(N).node()); - BufHandle c{alloc("c", dims, kFloat)}; - { - // alloc(c, 64); - AllocatePtr alloc = Allocate::make(c); - block.push_back(alloc); - } - - { - // for n in 0..64: // thread-idx - // c(n) = 0 - StorePtr store_cn_0 = Store::make(c, {n}, 0.f); - ForPtr loop_n1 = For::make(n, 0, N, store_cn_0, thread_idx_opt); - block.push_back(loop_n1); - } - - { - // for m in 0..128: - // for n in 0..64: // thread_idx - // c(n) = c(n) + a(k, m, n) - ExprHandle load_cn = Load::make(kFloat, c, {n}); - ExprHandle a_kmn = Load::make(a, {k * (M * N) + m * N + n}); - ExprHandle v_add = load_cn + a_kmn; - StorePtr store_cn_v = Store::make(c, {n}, v_add); - ForPtr loop_n2 = For::make(n, 0, N, store_cn_v, thread_idx_opt); - ForPtr loop_m1 = For::make(m, 0, M, loop_n2); - block.push_back(loop_m1); - } - - { - // b(k) = 0 - // for n in 0..64: // thread_idx - // b(k) = b(k) + c(n) - StorePtr store_bk_0 = b.store({k}, 0.f); - block.push_back(store_bk_0); - ExprHandle load_bk = b.load(k); - ExprHandle load_cn = Load::make(kFloat, c, {n}); - ExprHandle v_add = load_bk + load_cn; - StorePtr store_bk = b.store({k}, v_add); - ForPtr loop_n3 = For::make(n, 0, N, store_bk, thread_idx_opt); - block.push_back(loop_n3); - } - - { - // free(c) - FreePtr free_stmt = Free::make(c); - block.push_back(free_stmt); - } - - BlockPtr reduce_body = Block::make(block); - ForPtr loop_k1 = For::make(k, 0, 1, reduce_body, block_idx_opt); - - // TODO: check the generated code for correctness. - CudaCodeGen cuda_cg(loop_k1, a, b); - - std::ostringstream oss; - oss << *cuda_cg.stmt(); - - // Check the c write is not masked, but the d write is. 
- const std::string& verification_pattern = - R"IR( -# CHECK: c_1 = 0 -# CHECK: for (int m = 0; m < 128 -# CHECK: c_1 = c_1 + -# CHECK: __syncthreads(); -# CHECK: if (threadIdx.x<1 -# CHECK: b[blockIdx.x] = -# CHECK: __syncthreads(); -# CHECK: atomicAdd(&b[blockIdx.x], c_1) -)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - PaddedBuffer a_v(1, M, N, "a_v"); - PaddedBuffer b_v(1, "b_v"); - PaddedBuffer b_ref(1, "b_ref"); - - b_ref(0) = 0; - for (const auto i : c10::irange(M)) { - for (const auto j : c10::irange(N)) { - int v = i + j; - a_v(0, i, j) = v; - b_ref(0) += v; - } - } - - float* a_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&a_dev, kTotalSize * sizeof(float))); - C10_CUDA_CHECK(cudaMemcpy( - a_dev, a_v.data(), kTotalSize * sizeof(float), cudaMemcpyHostToDevice)); - float* b_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&b_dev, 1 * sizeof(float))); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - cuda_cg(a_dev, b_dev); - - C10_CUDA_CHECK(cudaDeviceSynchronize()); - C10_CUDA_CHECK( - cudaMemcpy(b_v.data(), b_dev, 1 * sizeof(float), cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - ExpectAllNear(b_v, b_ref, 1e-5); - - C10_CUDA_CHECK(cudaFree(a_dev)); - C10_CUDA_CHECK(cudaFree(b_dev)); -} - -TEST(Cuda, LocalMemReduce_1_CUDA) { - // This test does the following: - // for k in 0..1: // block-idx - // b(k) = 0 - // for n in 0..64: // thread-idx - // alloc(c, 1) - // c(0) = 0 - // for m in 0..128: - // c(0) = c(0) + a(k, m, n) - // b(k) = b(k) + c(0) - // free(c) - - const int M = 128; - const int N = 64; - const int kTotalSize = M * N; - LoopOptions thread_idx_opt; - thread_idx_opt.set_gpu_thread_index(0); - LoopOptions block_idx_opt; - block_idx_opt.set_gpu_block_index(0); - - BufHandle a("a", {1, M, N}, kFloat); - BufHandle b("b", {1}, kFloat); - VarHandle k("k", kInt); - VarHandle m("m", kInt); - VarHandle n("n", kInt); - - BufHandle c{ - alloc("c", std::vector({alloc(1)}), kFloat)}; - std::vector block_k; - { - // b(k) = 0 - StorePtr store_bk_0 = b.store({k}, 0.f); - block_k.push_back(store_bk_0); - } - std::vector block_n; - { - // alloc(c, 1); - AllocatePtr alloc = Allocate::make(c); - block_n.push_back(alloc); - } - { - // c(0) = 0 - StorePtr store_c0_0 = Store::make(c, {0}, 0.f); - block_n.push_back(store_c0_0); - } - { - // for m in 0..128: - // c(0) = c(0) + a(k, m, n) - ExprHandle load_c0 = Load::make(kFloat, c, {0}); - ExprHandle a_kmn = a.load(k * (M * N) + m * N + n); - ExprHandle v_add = load_c0 + a_kmn; - StorePtr store_c0_v = Store::make(c, {0}, v_add); - ForPtr loop_m = For::make(m, 0, M, store_c0_v); - block_n.push_back(loop_m); - } - { - // b(k) = b(k) + c(0) - ExprHandle load_bk = b.load(k); - ExprHandle load_c0 = Load::make(kFloat, c, {0}); - ExprHandle v_add = load_bk + load_c0; - StorePtr store_bk = b.store({k}, v_add); - block_n.push_back(store_bk); - } - { - // free(c) - FreePtr free_stmt = Free::make(c); - block_n.push_back(free_stmt); - } - { - BlockPtr block_n_stmt = Block::make(block_n); - ForPtr for_n = For::make(n, 0, N, block_n_stmt, thread_idx_opt); - block_k.push_back(for_n); - } - BlockPtr block_k_stmt = Block::make(block_k); - ForPtr loop_k = For::make(k, 0, 1, block_k_stmt, block_idx_opt); - - CudaCodeGen cuda_cg(loop_k, a, b); - PaddedBuffer a_v(1, M, N, "a_v"); - PaddedBuffer b_v(1, "b_v"); - PaddedBuffer b_ref(1, "b_ref"); - - b_ref(0) = 0; - for (const auto i : c10::irange(M)) { - for (const auto j : c10::irange(N)) { - int v = i + j; - a_v(0, i, j) = v; - b_ref(0) += v; - } - } - - float* a_dev 
= nullptr; - C10_CUDA_CHECK(cudaMalloc(&a_dev, kTotalSize * sizeof(float))); - C10_CUDA_CHECK(cudaMemcpy( - a_dev, a_v.data(), kTotalSize * sizeof(float), cudaMemcpyHostToDevice)); - float* b_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&b_dev, 1 * sizeof(float))); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - cuda_cg(a_dev, b_dev); - - C10_CUDA_CHECK(cudaDeviceSynchronize()); - C10_CUDA_CHECK( - cudaMemcpy(b_v.data(), b_dev, 1 * sizeof(float), cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - ExpectAllNear(b_v, b_ref, 1e-5); - - C10_CUDA_CHECK(cudaFree(a_dev)); - C10_CUDA_CHECK(cudaFree(b_dev)); -} - -TEST(Cuda, HalfSupport_CUDA) { - auto half = ToDtype(); - BufHandle a("a", {4}, half); - Tensor b = Compute("b", {4}, [&](const VarHandle& i) { - return Cast::make(half, ExprHandle(2.0f) * a.load(i)); - }); - - Tensor c = Compute("c", {4}, [&](const VarHandle& i) { - return Cast::make(kFloat, Cast::make(half, ExprHandle(42)) + b.load(i)); - }); - - Tensor d = Compute("d", {4}, [&](const VarHandle& i) { - return Cast::make(half, c.load(i)); - }); - - LoopNest l({b, c, d}); - l.prepareForCodegen(); - StmtPtr s = l.root_stmt(); - CudaCodeGen cg(s, {a, b, c, d}); - - std::vector aData(4, 2.0f); - std::vector cData(4, 0.0f); - std::vector dData(4, 0.0f); - at::Half* aDev = nullptr; - at::Half* bDev = nullptr; - at::Half* cDev = nullptr; - at::Half* dDev = nullptr; - auto aSize = aData.size() * sizeof(aData[0]); - auto bSize = aData.size() * sizeof(aData[0]); - auto cSize = cData.size() * sizeof(float); - auto dSize = dData.size() * sizeof(dData[0]); - - C10_CUDA_CHECK(cudaMalloc(&aDev, aSize)); - C10_CUDA_CHECK(cudaMalloc(&bDev, bSize)); - C10_CUDA_CHECK(cudaMalloc(&cDev, cSize)); - C10_CUDA_CHECK(cudaMalloc(&dDev, dSize)); - C10_CUDA_CHECK(cudaMemcpy(aDev, aData.data(), aSize, cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy(cDev, cData.data(), cSize, cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy(dDev, dData.data(), dSize, cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - cg.call({aDev, bDev, cDev, dDev}); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - C10_CUDA_CHECK(cudaMemcpy(aData.data(), aDev, aSize, cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaMemcpy(cData.data(), cDev, cSize, cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaMemcpy(dData.data(), dDev, dSize, cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - assertAllEqual(cData, 46.0f); - - C10_CUDA_CHECK(cudaFree(aDev)); - C10_CUDA_CHECK(cudaFree(bDev)); - C10_CUDA_CHECK(cudaFree(cDev)); - C10_CUDA_CHECK(cudaFree(dDev)); -} - -TEST(Cuda, HalfPropagation_CUDA) { - auto half = ToDtype(); - BufHandle a("a", {4}, half); - Tensor relu = Compute("relu", {4}, [&](const VarHandle& i) { - return Max::make(a.load(i), ExprHandle(alloc(0)), true); - }); - - LoopNest l({relu}); - l.prepareForCodegen(); - StmtPtr s = l.root_stmt(); - CudaCodeGen cg(s, {a, relu}); - - std::ostringstream oss; - oss << *cg.stmt(); - - // Check the types used by the Max are Float. 
-  const std::string& verification_pattern =
-      R"IR(
-# CHECK: for (
-# CHECK: float v = float(a[i]);
-# CHECK: relu[i] = half(Max(v, 0.f
-# CHECK: })IR";
-
-  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
-
-  std::vector<float> aData(4, 2.0f);
-  std::vector<float> reluData(4, 0.0f);
-  at::Half* aDev = nullptr;
-  at::Half* reluDev = nullptr;
-  auto aSize = aData.size() * sizeof(aData[0]);
-  auto reluSize = reluData.size() * sizeof(reluData[0]);
-
-  C10_CUDA_CHECK(cudaMalloc(&aDev, aSize));
-  C10_CUDA_CHECK(cudaMalloc(&reluDev, reluSize));
-  C10_CUDA_CHECK(cudaMemcpy(aDev, aData.data(), aSize, cudaMemcpyHostToDevice));
-  C10_CUDA_CHECK(
-      cudaMemcpy(reluDev, reluData.data(), reluSize, cudaMemcpyHostToDevice));
-  C10_CUDA_CHECK(cudaDeviceSynchronize());
-
-  cg.call({aDev, reluDev});
-  C10_CUDA_CHECK(
-      cudaMemcpy(reluData.data(), reluDev, reluSize, cudaMemcpyDeviceToHost));
-  C10_CUDA_CHECK(cudaDeviceSynchronize());
-
-  assertAllEqual(aData, reluData);
-
-  C10_CUDA_CHECK(cudaFree(aDev));
-  C10_CUDA_CHECK(cudaFree(reluDev));
-}
-
-TEST(Cuda, UnusedHalfArgument_CUDA) {
-  BufHandle a("a", {4}, kFloat);
-  auto half = ToDtype<at::Half>();
-  BufHandle b("b", {4}, half);
-  Tensor relu = Compute("relu", {4}, [&](const VarHandle& i) {
-    return Max::make(a.load(i), ExprHandle(alloc<FloatImm>(0)), true);
-  });
-
-  LoopNest l({relu});
-  l.prepareForCodegen();
-  StmtPtr s = l.root_stmt();
-  CudaCodeGen cg(s, {a, b, relu});
-
-  std::ostringstream oss;
-  oss << *cg.stmt();
-
-  // Check the types used by the Max are Float.
-  const std::string& verification_pattern =
-      R"IR(
-# CHECK: for (
-# CHECK: float v = a[i];
-# CHECK: relu[i] = Max(v, 0.f
-# CHECK: })IR";
-
-  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
-
-  // Sanity check.
-  std::vector<float> aData(4, 2.0f);
-  std::vector<float> bData(4, 2.0f);
-  std::vector<float> reluData(4, 0.0f);
-  at::Half* aDev = nullptr;
-  at::Half* bDev = nullptr;
-  at::Half* reluDev = nullptr;
-  auto aSize = aData.size() * sizeof(aData[0]);
-  auto bSize = bData.size() * sizeof(bData[0]);
-  auto reluSize = reluData.size() * sizeof(reluData[0]);
-
-  C10_CUDA_CHECK(cudaMalloc(&aDev, aSize));
-  C10_CUDA_CHECK(cudaMalloc(&bDev, bSize));
-  C10_CUDA_CHECK(cudaMalloc(&reluDev, reluSize));
-  C10_CUDA_CHECK(cudaMemcpy(aDev, aData.data(), aSize, cudaMemcpyHostToDevice));
-  C10_CUDA_CHECK(cudaMemcpy(bDev, bData.data(), bSize, cudaMemcpyHostToDevice));
-  C10_CUDA_CHECK(
-      cudaMemcpy(reluDev, reluData.data(), reluSize, cudaMemcpyHostToDevice));
-  C10_CUDA_CHECK(cudaDeviceSynchronize());
-
-  cg.call({aDev, bDev, reluDev});
-  C10_CUDA_CHECK(
-      cudaMemcpy(reluData.data(), reluDev, reluSize, cudaMemcpyDeviceToHost));
-  C10_CUDA_CHECK(cudaDeviceSynchronize());
-
-  assertAllEqual(aData, reluData);
-
-  C10_CUDA_CHECK(cudaFree(aDev));
-  C10_CUDA_CHECK(cudaFree(bDev));
-  C10_CUDA_CHECK(cudaFree(reluDev));
-}
-
-TEST(Cuda, PrioritizeDependents_CUDA) {
-  BufHandle a("a", {10}, kFloat);
-  BufHandle b("b", {12}, kFloat);
-  BufHandle c("c", {12}, kFloat);
-
-  LoopOptions block_idx_opt;
-  block_idx_opt.set_gpu_block_index(0);
-
-  VarHandle i("i", kInt);
-  VarHandle j("j", kInt);
-
-  /*
-   * for (const auto i : c10::irange(12)) {
-   *   c[i] = (i < 10 ?
a[i] + b[i] : b[i]); - * } - */ - ExprHandle load_a = a.load({i}); - ExprHandle load_b = b.load({i}); - ExprHandle cmp = CompareSelect::make(i, 10, CompareSelectOperation::kLT); - ExprHandle ite = IfThenElse::make(cmp, Add::make(load_a, load_b), load_b); - - ForPtr loop = - For::make(i, 0, 12, Block::make({c.store({i}, ite)}), block_idx_opt); - - CudaCodeGen cuda_cg(loop, a, b, c); - - PaddedBuffer a_v(10, "a_v"); - PaddedBuffer b_v(12, "b_v"); - PaddedBuffer c_v(12, "c_v"); - PaddedBuffer c_ref(12, "c_ref"); - - for (const auto i : c10::irange(10)) { - a_v(i) = i * 100; - b_v(i) = i; - c_v(i) = 0; - } - - for (const auto i : c10::irange(10, 12)) { - b_v(i) = i; - c_v(i) = 0; - } - - float* a_dev = nullptr; - float* b_dev = nullptr; - float* c_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&a_dev, 10 * sizeof(float))); - C10_CUDA_CHECK(cudaMalloc(&b_dev, 12 * sizeof(float))); - C10_CUDA_CHECK(cudaMalloc(&c_dev, 12 * sizeof(float))); - - C10_CUDA_CHECK(cudaMemcpy( - a_dev, a_v.data(), 10 * sizeof(float), cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - b_dev, b_v.data(), 12 * sizeof(float), cudaMemcpyHostToDevice)); - - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - cuda_cg(a_dev, b_dev, c_dev); - - C10_CUDA_CHECK(cudaDeviceSynchronize()); - C10_CUDA_CHECK(cudaMemcpy( - c_v.data(), c_dev, 12 * sizeof(float), cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - for (const auto i : c10::irange(12)) { - if (i < 10) { - c_ref(i) = i + i * 100; - } else { - c_ref(i) = i; - } - } - - ExpectAllNear(c_v, c_ref, 1e-5); -} - -/// Tests the case where there are two loops which have different extents bound -/// to the same block dimension. We must mask the smaller extent loop body. -TEST(Cuda, MaskBlockDim_CUDA) { - int A_SIZE = 100; - int B_SIZE = 50; - BufHandle a_buf("a", {A_SIZE}, kFloat); - BufHandle b_buf("b", {B_SIZE}, kFloat); - Tensor c = Compute( - "c", {A_SIZE}, [&](const VarHandle& i) { return a_buf.load(i) + 10; }); - Tensor d = Compute("d", {B_SIZE}, [&](const VarHandle& i) { - return a_buf.load(i) + b_buf.load(i); - }); - - LoopNest l({c, d}); - std::vector loops = l.getLoopStmtsFor(c); - loops[0]->set_gpu_block_index(0); - loops = l.getLoopStmtsFor(d); - loops[0]->set_gpu_block_index(0); - - l.prepareForCodegen(); - StmtPtr stmt = l.root_stmt(); - CudaCodeGen cuda_cg(stmt, c, d, a_buf, b_buf); - - std::ostringstream oss; - oss << *cuda_cg.stmt(); - - // Check the c write is not masked, but the d write is. - const std::string& verification_pattern = - R"IR( -# CHECK-NOT: if (blockIdx -# CHECK: c[blockIdx.x] = -# CHECK: if (blockIdx.x<50 -# CHECK: d[blockIdx.x] =)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - auto blockExtents = cuda_cg.gpu_block_extents(); - auto threadExtents = cuda_cg.gpu_thread_extents(); - ASSERT_TRUE(exprEquals(blockExtents[0], alloc(A_SIZE))); - ASSERT_TRUE(exprEquals(threadExtents[0], alloc(1))); - - // Sanity check that the kernel works. 
- PaddedBuffer a_v(A_SIZE); - PaddedBuffer b_v(B_SIZE); - PaddedBuffer c_v(A_SIZE); - PaddedBuffer d_v(B_SIZE); - - PaddedBuffer c_ref(A_SIZE); - PaddedBuffer d_ref(B_SIZE); - - for (const auto i : c10::irange(A_SIZE)) { - a_v(i) = (float)i; - c_ref(i) = (float)(i + 10); - } - - for (const auto i : c10::irange(B_SIZE)) { - b_v(i) = (float)(B_SIZE - i); - d_ref(i) = a_v(i) + b_v(i); - } - - float* a_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&a_dev, A_SIZE * sizeof(float))); - float* b_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&b_dev, B_SIZE * sizeof(float))); - float* c_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&c_dev, A_SIZE * sizeof(float))); - float* d_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&d_dev, B_SIZE * sizeof(float))); - C10_CUDA_CHECK(cudaMemcpy( - a_dev, a_v.data(), A_SIZE * sizeof(float), cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - b_dev, b_v.data(), B_SIZE * sizeof(float), cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - c_dev, c_v.data(), A_SIZE * sizeof(float), cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - d_dev, d_v.data(), B_SIZE * sizeof(float), cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - cuda_cg(c_dev, d_dev, a_dev, b_dev); - - C10_CUDA_CHECK(cudaDeviceSynchronize()); - C10_CUDA_CHECK(cudaMemcpy( - c_v.data(), c_dev, A_SIZE * sizeof(float), cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaMemcpy( - d_v.data(), d_dev, B_SIZE * sizeof(float), cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - ExpectAllNear(c_v, c_ref, 1e-5); - ExpectAllNear(d_v, d_ref, 1e-5); - - C10_CUDA_CHECK(cudaFree(a_dev)); - C10_CUDA_CHECK(cudaFree(b_dev)); - C10_CUDA_CHECK(cudaFree(c_dev)); - C10_CUDA_CHECK(cudaFree(d_dev)); -} - -/// Tests the case with two loops, which have different extents that are bound -/// to the same thread dimension. This is the same as the above - the smaller -/// rank write should be masked. But this time we also need to syncthreads. -TEST(Cuda, MaskThreadDim_CUDA) { - int A_SIZE = 50; - int B_SIZE = 100; - BufHandle a_buf("a", {A_SIZE}, kFloat); - BufHandle b_buf("b", {B_SIZE}, kFloat); - Tensor c = Compute( - "c", {A_SIZE}, [&](const VarHandle& i) { return a_buf.load(i) + 10; }); - Tensor d = Compute("d", {B_SIZE}, [&](const VarHandle& i) { - return a_buf.load(i / 2) + b_buf.load(i); - }); - - LoopNest l({c, d}); - std::vector loops = l.getLoopStmtsFor(c); - loops[0]->set_gpu_thread_index(0); - loops = l.getLoopStmtsFor(d); - loops[0]->set_gpu_thread_index(0); - - l.prepareForCodegen(); - StmtPtr stmt = l.root_stmt(); - CudaCodeGen cuda_cg(stmt, c, d, a_buf, b_buf); - - std::ostringstream oss; - oss << *cuda_cg.stmt(); - - // Check the c write is masked, but the d write is not. 
- const std::string& verification_pattern = - R"IR( -# CHECK: if (threadIdx.x<50 -# CHECK: c[threadIdx.x] = -# CHECK: __syncthreads(); -# CHECK-NOT: if (threadIdx.x -# CHECK: d[threadIdx.x] =)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - auto blockExtents = cuda_cg.gpu_block_extents(); - auto threadExtents = cuda_cg.gpu_thread_extents(); - ASSERT_TRUE(exprEquals(blockExtents[0], alloc(1))); - ASSERT_TRUE(exprEquals(threadExtents[0], alloc(B_SIZE))); - - PaddedBuffer a_v(A_SIZE); - PaddedBuffer b_v(B_SIZE); - PaddedBuffer c_v(A_SIZE); - PaddedBuffer d_v(B_SIZE); - - PaddedBuffer c_ref(A_SIZE); - PaddedBuffer d_ref(B_SIZE); - - for (const auto i : c10::irange(A_SIZE)) { - a_v(i) = (float)i; - c_ref(i) = (float)(i + 10); - } - - for (const auto i : c10::irange(B_SIZE)) { - b_v(i) = (float)(B_SIZE - i); - d_ref(i) = a_v(i / 2) + b_v(i); - } - - float* a_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&a_dev, A_SIZE * sizeof(float))); - float* b_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&b_dev, B_SIZE * sizeof(float))); - float* c_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&c_dev, A_SIZE * sizeof(float))); - float* d_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&d_dev, B_SIZE * sizeof(float))); - C10_CUDA_CHECK(cudaMemcpy( - a_dev, a_v.data(), A_SIZE * sizeof(float), cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - b_dev, b_v.data(), B_SIZE * sizeof(float), cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - c_dev, c_v.data(), A_SIZE * sizeof(float), cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - d_dev, d_v.data(), B_SIZE * sizeof(float), cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - cuda_cg(c_dev, d_dev, a_dev, b_dev); - - C10_CUDA_CHECK(cudaDeviceSynchronize()); - C10_CUDA_CHECK(cudaMemcpy( - c_v.data(), c_dev, A_SIZE * sizeof(float), cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaMemcpy( - d_v.data(), d_dev, B_SIZE * sizeof(float), cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - ExpectAllNear(c_v, c_ref, 1e-5); - ExpectAllNear(d_v, d_ref, 1e-5); - - C10_CUDA_CHECK(cudaFree(a_dev)); - C10_CUDA_CHECK(cudaFree(b_dev)); - C10_CUDA_CHECK(cudaFree(c_dev)); - C10_CUDA_CHECK(cudaFree(d_dev)); -} - -/// Tests the case where there are two loops, and each is bound to a different -/// block dimension. In this case all writes should be masked since they occur -/// in distinct dimensions. -// Note: this is an extremely dumb pattern which we should never see, but is a -// useful edge case to make sure we've got things covered. -TEST(Cuda, MaskMultiBlockDim_CUDA) { - int A_SIZE = 100; - int B_SIZE = 50; - BufHandle a_buf("a", {A_SIZE}, kFloat); - BufHandle b_buf("b", {B_SIZE}, kFloat); - Tensor c = Compute( - "c", {A_SIZE}, [&](const VarHandle& i) { return a_buf.load(i) + 10; }); - Tensor d = Compute("d", {B_SIZE}, [&](const VarHandle& i) { - return a_buf.load(i) + b_buf.load(i); - }); - - LoopNest l({c, d}); - std::vector loops = l.getLoopStmtsFor(c); - loops[0]->set_gpu_block_index(0); - loops = l.getLoopStmtsFor(d); - loops[0]->set_gpu_block_index(1); - - l.prepareForCodegen(); - StmtPtr stmt = l.root_stmt(); - CudaCodeGen cuda_cg(stmt, c, d, a_buf, b_buf); - - std::ostringstream oss; - oss << *cuda_cg.stmt(); - - // Write to c should be masked against y, write to d against x. 
- const std::string& verification_pattern = - R"IR( -# CHECK: if (blockIdx.y<1 -# CHECK: c[blockIdx.x] = -# CHECK: if (blockIdx.x<1 -# CHECK: d[blockIdx.y] =)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - auto blockExtents = cuda_cg.gpu_block_extents(); - auto threadExtents = cuda_cg.gpu_thread_extents(); - ASSERT_TRUE(exprEquals(blockExtents[0], alloc(A_SIZE))); - ASSERT_TRUE(exprEquals(blockExtents[1], alloc(B_SIZE))); - - PaddedBuffer a_v(A_SIZE); - PaddedBuffer b_v(B_SIZE); - PaddedBuffer c_v(A_SIZE); - PaddedBuffer d_v(B_SIZE); - - PaddedBuffer c_ref(A_SIZE); - PaddedBuffer d_ref(B_SIZE); - - for (const auto i : c10::irange(A_SIZE)) { - a_v(i) = (float)i; - c_ref(i) = (float)(i + 10); - } - - for (const auto i : c10::irange(B_SIZE)) { - b_v(i) = (float)(B_SIZE - i); - d_ref(i) = a_v(i) + b_v(i); - } - - float* a_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&a_dev, A_SIZE * sizeof(float))); - float* b_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&b_dev, B_SIZE * sizeof(float))); - float* c_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&c_dev, A_SIZE * sizeof(float))); - float* d_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&d_dev, B_SIZE * sizeof(float))); - C10_CUDA_CHECK(cudaMemcpy( - a_dev, a_v.data(), A_SIZE * sizeof(float), cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - b_dev, b_v.data(), B_SIZE * sizeof(float), cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - c_dev, c_v.data(), A_SIZE * sizeof(float), cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - d_dev, d_v.data(), B_SIZE * sizeof(float), cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - cuda_cg(c_dev, d_dev, a_dev, b_dev); - - C10_CUDA_CHECK(cudaDeviceSynchronize()); - C10_CUDA_CHECK(cudaMemcpy( - c_v.data(), c_dev, A_SIZE * sizeof(float), cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaMemcpy( - d_v.data(), d_dev, B_SIZE * sizeof(float), cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - ExpectAllNear(c_v, c_ref, 1e-5); - ExpectAllNear(d_v, d_ref, 1e-5); - - C10_CUDA_CHECK(cudaFree(a_dev)); - C10_CUDA_CHECK(cudaFree(b_dev)); - C10_CUDA_CHECK(cudaFree(c_dev)); - C10_CUDA_CHECK(cudaFree(d_dev)); -} - -/// Tests the case where both the blockDim and threadDim are bound to different -/// loops. In this instance both stores should be masked since they are -/// distinct. -// Note: this is an extremely dumb pattern which we should never see, but is a -// useful edge case to make sure we've got things covered. 
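// Before the test itself, a hand-written CUDA sketch of the kernel shape it
// checks for (the kernel name and signature are illustrative, not the
// generated source): each store is guarded against the dimension it does not
// use, matching the FileCheck pattern inside the test below.
__global__ void maskBlockAndThreadSketch(
    const float* a, const float* b, float* c, float* d) {
  if (threadIdx.x < 1) {
    c[blockIdx.x] = a[blockIdx.x] + 10.f;  // c spans blockIdx.x only
  }
  if (blockIdx.x < 1) {
    d[threadIdx.x] = a[threadIdx.x] + b[threadIdx.x];  // d spans threadIdx.x
  }
}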
-TEST(Cuda, MaskBlockAndThreadDim_CUDA) { - int A_SIZE = 100; - int B_SIZE = 50; - BufHandle a_buf("a", {A_SIZE}, kFloat); - BufHandle b_buf("b", {B_SIZE}, kFloat); - Tensor c = Compute( - "c", {A_SIZE}, [&](const VarHandle& i) { return a_buf.load(i) + 10; }); - Tensor d = Compute("d", {B_SIZE}, [&](const VarHandle& i) { - return a_buf.load(i) + b_buf.load(i); - }); - - LoopNest l({c, d}); - std::vector loops = l.getLoopStmtsFor(c); - loops[0]->set_gpu_block_index(0); - loops = l.getLoopStmtsFor(d); - loops[0]->set_gpu_thread_index(0); - - l.prepareForCodegen(); - StmtPtr stmt = l.root_stmt(); - CudaCodeGen cuda_cg(stmt, c, d, a_buf, b_buf); - - std::ostringstream oss; - oss << *cuda_cg.stmt(); - - const std::string& verification_pattern = - R"IR( -# CHECK: if (threadIdx.x<1 -# CHECK: c[blockIdx.x] = -# CHECK: } -# CHECK: if (blockIdx.x<1 -# CHECK: d[threadIdx.x] =)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - auto blockExtents = cuda_cg.gpu_block_extents(); - auto threadExtents = cuda_cg.gpu_thread_extents(); - ASSERT_TRUE(exprEquals(blockExtents[0], alloc(A_SIZE))); - ASSERT_TRUE(exprEquals(threadExtents[0], alloc(B_SIZE))); - - PaddedBuffer a_v(A_SIZE); - PaddedBuffer b_v(B_SIZE); - PaddedBuffer c_v(A_SIZE); - PaddedBuffer d_v(B_SIZE); - - PaddedBuffer c_ref(A_SIZE); - PaddedBuffer d_ref(B_SIZE); - - for (const auto i : c10::irange(A_SIZE)) { - a_v(i) = (float)i; - c_ref(i) = (float)(i + 10); - } - - for (const auto i : c10::irange(B_SIZE)) { - b_v(i) = (float)(B_SIZE - i); - d_ref(i) = a_v(i) + b_v(i); - } - - float* a_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&a_dev, A_SIZE * sizeof(float))); - float* b_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&b_dev, B_SIZE * sizeof(float))); - float* c_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&c_dev, A_SIZE * sizeof(float))); - float* d_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&d_dev, B_SIZE * sizeof(float))); - C10_CUDA_CHECK(cudaMemcpy( - a_dev, a_v.data(), A_SIZE * sizeof(float), cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - b_dev, b_v.data(), B_SIZE * sizeof(float), cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - c_dev, c_v.data(), A_SIZE * sizeof(float), cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - d_dev, d_v.data(), B_SIZE * sizeof(float), cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - cuda_cg(c_dev, d_dev, a_dev, b_dev); - - C10_CUDA_CHECK(cudaDeviceSynchronize()); - C10_CUDA_CHECK(cudaMemcpy( - c_v.data(), c_dev, A_SIZE * sizeof(float), cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaMemcpy( - d_v.data(), d_dev, B_SIZE * sizeof(float), cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - ExpectAllNear(c_v, c_ref, 1e-5); - ExpectAllNear(d_v, d_ref, 1e-5); - - C10_CUDA_CHECK(cudaFree(a_dev)); - C10_CUDA_CHECK(cudaFree(b_dev)); - C10_CUDA_CHECK(cudaFree(c_dev)); - C10_CUDA_CHECK(cudaFree(d_dev)); -} - -/// Tests the case where the loopnest has two loops of depth two: each with the -/// outer loop bound to blockDim.x and the inner loop bound to threadDim.x. In -/// this case all writes with a rank smaller than the max should be masked. 
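// A CUDA sketch of the shape MaskMultiDim below expects (illustrative, not
// the generated source): the outer extent maps to blockIdx.x and both inner
// extents to threadIdx.x, so only the narrower D store carries a mask.
__global__ void maskMultiDimSketch(
    const float* a, const float* b, float* C, float* D) {
  C[threadIdx.x + 100 * blockIdx.x] = 2.f * a[threadIdx.x + 100 * blockIdx.x];
  __syncthreads();
  if (threadIdx.x < 50) {
    D[threadIdx.x + 50 * blockIdx.x] =
        C[2 * threadIdx.x + 100 * blockIdx.x] + b[threadIdx.x + 50 * blockIdx.x];
  }
}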
-TEST(Cuda, MaskMultiDim_CUDA) { - int OUTER_SIZE = 10; - int A_SIZE = 100; - int B_SIZE = 50; - BufHandle a_buf("a", {OUTER_SIZE, A_SIZE}, kFloat); - BufHandle b_buf("b", {OUTER_SIZE, B_SIZE}, kFloat); - Tensor c = Compute( - "C", {OUTER_SIZE, A_SIZE}, [&](const VarHandle& i, const VarHandle& j) { - return ExprHandle(2) * a_buf.load(i, j); - }); - Tensor d = Compute( - "D", {OUTER_SIZE, B_SIZE}, [&](const VarHandle& i, const VarHandle& j) { - return c.load(i, j * 2) + b_buf.load(i, j); - }); - - LoopNest l({c, d}); - std::vector loops = l.getLoopStmtsFor(c); - loops[0]->set_gpu_block_index(0); - loops[1]->set_gpu_thread_index(0); - loops = l.getLoopStmtsFor(d); - loops[0]->set_gpu_block_index(0); - loops[1]->set_gpu_thread_index(0); - - l.prepareForCodegen(); - StmtPtr stmt = l.root_stmt(); - CudaCodeGen cuda_cg(stmt, c, d, a_buf, b_buf); - - std::ostringstream oss; - oss << *cuda_cg.stmt(); - - // The write to D should be masked, but not the write to C. - const std::string& verification_pattern = - R"IR( -# CHECK-NOT: if ( -# CHECK: C[threadIdx.x + 100 * blockIdx.x] = -# CHECK: __syncthreads(); -# CHECK: if (threadIdx.x<50 -# CHECK: D[threadIdx.x + 50 * blockIdx.x] =)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - auto blockExtents = cuda_cg.gpu_block_extents(); - auto threadExtents = cuda_cg.gpu_thread_extents(); - ASSERT_TRUE(exprEquals(blockExtents[0], alloc(OUTER_SIZE))); - ASSERT_TRUE(exprEquals(threadExtents[0], alloc(A_SIZE))); - - PaddedBuffer a_v(OUTER_SIZE, A_SIZE); - PaddedBuffer b_v(OUTER_SIZE, B_SIZE); - PaddedBuffer c_v(OUTER_SIZE, A_SIZE); - PaddedBuffer d_v(OUTER_SIZE, B_SIZE); - - PaddedBuffer c_ref(OUTER_SIZE, A_SIZE); - PaddedBuffer d_ref(OUTER_SIZE, B_SIZE); - - for (const auto o : c10::irange(OUTER_SIZE)) { - for (const auto i : c10::irange(A_SIZE)) { - a_v(o, i) = (float)i; - c_ref(o, i) = (float)(i * 2); - } - } - - for (const auto o : c10::irange(OUTER_SIZE)) { - for (const auto i : c10::irange(B_SIZE)) { - b_v(o, i) = (float)(B_SIZE - i); - d_ref(o, i) = c_ref(o, i * 2) + b_v(o, i); - } - } - - float* a_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&a_dev, OUTER_SIZE * A_SIZE * sizeof(float))); - float* b_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&b_dev, OUTER_SIZE * B_SIZE * sizeof(float))); - float* c_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&c_dev, OUTER_SIZE * A_SIZE * sizeof(float))); - float* d_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&d_dev, OUTER_SIZE * B_SIZE * sizeof(float))); - C10_CUDA_CHECK(cudaMemcpy( - a_dev, - a_v.data(), - OUTER_SIZE * A_SIZE * sizeof(float), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - b_dev, - b_v.data(), - OUTER_SIZE * B_SIZE * sizeof(float), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - c_dev, - c_v.data(), - OUTER_SIZE * A_SIZE * sizeof(float), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - d_dev, - d_v.data(), - OUTER_SIZE * B_SIZE * sizeof(float), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - cuda_cg(c_dev, d_dev, a_dev, b_dev); - - C10_CUDA_CHECK(cudaDeviceSynchronize()); - C10_CUDA_CHECK(cudaMemcpy( - c_v.data(), - c_dev, - OUTER_SIZE * A_SIZE * sizeof(float), - cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaMemcpy( - d_v.data(), - d_dev, - OUTER_SIZE * B_SIZE * sizeof(float), - cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - ExpectAllNear(c_v, c_ref, 1e-5); - ExpectAllNear(d_v, d_ref, 1e-5); - - C10_CUDA_CHECK(cudaFree(a_dev)); - C10_CUDA_CHECK(cudaFree(b_dev)); - 
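// Every test in this file repeats the cudaMalloc / cudaMemcpy / cudaFree
// dance by hand. A hypothetical RAII helper (not part of the test suite,
// sketched here only to show the pattern) could have expressed it once:
#include <c10/cuda/CUDAException.h>
#include <cuda_runtime.h>
#include <vector>

struct DeviceBuffer {
  float* ptr = nullptr;
  size_t bytes = 0;
  explicit DeviceBuffer(const std::vector<float>& host)
      : bytes(host.size() * sizeof(float)) {
    C10_CUDA_CHECK(cudaMalloc(&ptr, bytes));
    C10_CUDA_CHECK(
        cudaMemcpy(ptr, host.data(), bytes, cudaMemcpyHostToDevice));
  }
  DeviceBuffer(const DeviceBuffer&) = delete;
  DeviceBuffer& operator=(const DeviceBuffer&) = delete;
  void copyTo(std::vector<float>& host) const {
    C10_CUDA_CHECK(
        cudaMemcpy(host.data(), ptr, bytes, cudaMemcpyDeviceToHost));
  }
  ~DeviceBuffer() {
    cudaFree(ptr);  // no C10_CUDA_CHECK: never throw from a destructor
  }
};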
C10_CUDA_CHECK(cudaFree(c_dev));
-  C10_CUDA_CHECK(cudaFree(d_dev));
-}
-
-// Tests the case where loop extents are symbolic and not known at compile time.
-// In this case both stores must be masked against the extent of the other loop,
-// in case it is larger.
-TEST(Cuda, MaskMultiDimSymbolic_CUDA) {
-  VarHandle OUTER_SIZE("OUTER_SIZE", kLong);
-  VarHandle A_SIZE("A_SIZE", kLong);
-  VarHandle B_SIZE("B_SIZE", kLong);
-  BufHandle a_buf("a", {OUTER_SIZE, A_SIZE}, kFloat);
-  BufHandle b_buf("b", {OUTER_SIZE, B_SIZE}, kFloat);
-  Tensor c = Compute(
-      "C", {OUTER_SIZE, A_SIZE}, [&](const VarHandle& i, const VarHandle& j) {
-        return ExprHandle(2) * a_buf.load(i, j);
-      });
-  Tensor d = Compute(
-      "D", {OUTER_SIZE, B_SIZE}, [&](const VarHandle& i, const VarHandle& j) {
-        return c.load(i, j * 2) + b_buf.load(i, j);
-      });
-
-  LoopNest l({c, d});
-  std::vector<ForPtr> loops = l.getLoopStmtsFor(c);
-  loops[0]->set_gpu_block_index(0);
-  loops[1]->set_gpu_thread_index(0);
-  loops = l.getLoopStmtsFor(d);
-  loops[0]->set_gpu_block_index(0);
-  loops[1]->set_gpu_thread_index(0);
-
-  l.prepareForCodegen();
-  StmtPtr stmt = l.root_stmt();
-  CudaCodeGen cuda_cg(stmt, c, d, OUTER_SIZE, A_SIZE, B_SIZE, a_buf, b_buf);
-
-  std::ostringstream oss;
-  oss << *cuda_cg.stmt();
-
-  // Since we don't know which is bigger (A_SIZE or B_SIZE) we must mask both.
-  const std::string& verification_pattern =
-      R"IR(
-# CHECK: if (threadIdx.x<A_SIZE
-# CHECK: C[
-# CHECK: __syncthreads();
-# CHECK: if (threadIdx.x<B_SIZE
-# CHECK: D[)IR";
-
-  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
-
-  auto blockExtents = cuda_cg.gpu_block_extents();
-  auto threadExtents = cuda_cg.gpu_thread_extents();
-  ASSERT_TRUE(exprEquals(blockExtents[0], OUTER_SIZE.node()));
-  ASSERT_TRUE(exprEquals(
-      threadExtents[0], alloc<Max>(A_SIZE.node(), B_SIZE.node(), true)));
-
-  int64_t OUTER_EXTENT = 10;
-  int64_t A_EXTENT = 100;
-  int64_t B_EXTENT = 50;
-
-  PaddedBuffer<float> a_v(OUTER_EXTENT, A_EXTENT);
-  PaddedBuffer<float> b_v(OUTER_EXTENT, B_EXTENT);
-  PaddedBuffer<float> c_v(OUTER_EXTENT, A_EXTENT);
-  PaddedBuffer<float> d_v(OUTER_EXTENT, B_EXTENT);
-
-  PaddedBuffer<float> c_ref(OUTER_EXTENT, A_EXTENT);
-  PaddedBuffer<float> d_ref(OUTER_EXTENT, B_EXTENT);
-
-  for (const auto o : c10::irange(OUTER_EXTENT)) {
-    for (const auto i : c10::irange(A_EXTENT)) {
-      a_v(o, i) = (float)i;
-      c_ref(o, i) = (float)(i * 2);
-    }
-  }
-
-  for (const auto o : c10::irange(OUTER_EXTENT)) {
-    for (const auto i : c10::irange(B_EXTENT)) {
-      b_v(o, i) = (float)(B_EXTENT - i);
-      d_ref(o, i) = c_ref(o, i * 2) + b_v(o, i);
-    }
-  }
-
-  float* a_dev = nullptr;
-  C10_CUDA_CHECK(cudaMalloc(&a_dev, OUTER_EXTENT * A_EXTENT * sizeof(float)));
-  float* b_dev = nullptr;
-  C10_CUDA_CHECK(cudaMalloc(&b_dev, OUTER_EXTENT * B_EXTENT * sizeof(float)));
-  float* c_dev = nullptr;
-  C10_CUDA_CHECK(cudaMalloc(&c_dev, OUTER_EXTENT * A_EXTENT * sizeof(float)));
-  float* d_dev = nullptr;
-  C10_CUDA_CHECK(cudaMalloc(&d_dev, OUTER_EXTENT * B_EXTENT * sizeof(float)));
-  C10_CUDA_CHECK(cudaMemcpy(
-      a_dev,
-      a_v.data(),
-      OUTER_EXTENT * A_EXTENT * sizeof(float),
-      cudaMemcpyHostToDevice));
-  C10_CUDA_CHECK(cudaMemcpy(
-      b_dev,
-      b_v.data(),
-      OUTER_EXTENT * B_EXTENT * sizeof(float),
-      cudaMemcpyHostToDevice));
-  C10_CUDA_CHECK(cudaMemcpy(
-      c_dev,
-      c_v.data(),
-      OUTER_EXTENT * A_EXTENT * sizeof(float),
-      cudaMemcpyHostToDevice));
-  C10_CUDA_CHECK(cudaMemcpy(
-      d_dev,
-      d_v.data(),
-      OUTER_EXTENT * B_EXTENT * sizeof(float),
-      cudaMemcpyHostToDevice));
-  C10_CUDA_CHECK(cudaDeviceSynchronize());
-
-  cuda_cg(c_dev, d_dev, OUTER_EXTENT, A_EXTENT, B_EXTENT, a_dev, b_dev);
-
-  C10_CUDA_CHECK(cudaDeviceSynchronize());
-  C10_CUDA_CHECK(cudaMemcpy(
-      c_v.data(),
-      c_dev,
-      OUTER_EXTENT * A_EXTENT * sizeof(float),
-      cudaMemcpyDeviceToHost));
-  C10_CUDA_CHECK(cudaMemcpy(
-      d_v.data(),
-      d_dev,
-      OUTER_EXTENT * B_EXTENT * sizeof(float),
-      cudaMemcpyDeviceToHost));
-  C10_CUDA_CHECK(cudaDeviceSynchronize());
-
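// Why both masks are required here: the kernel runs with
// blockDim.x = Max(A_SIZE, B_SIZE) (the thread extent asserted above), so
// whichever loop is shorter sees surplus threads. A host-side model of the
// predicate logic (illustrative only):
#include <algorithm>
#include <cassert>
#include <cstdint>

void modelSymbolicMasking(int64_t A, int64_t B) {
  int64_t threads = std::max(A, B);  // gpu_thread_extents()[0]
  int64_t c_writes = 0, d_writes = 0;
  for (int64_t t = 0; t < threads; ++t) {
    if (t < A) ++c_writes;  // mask guarding the C store
    if (t < B) ++d_writes;  // mask guarding the D store
  }
  assert(c_writes == A && d_writes == B);  // no out-of-range stores
}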
ExpectAllNear(c_v, c_ref, 1e-5); - ExpectAllNear(d_v, d_ref, 1e-5); - - C10_CUDA_CHECK(cudaFree(a_dev)); - C10_CUDA_CHECK(cudaFree(b_dev)); - C10_CUDA_CHECK(cudaFree(c_dev)); - C10_CUDA_CHECK(cudaFree(d_dev)); -} - -// Tests the case where two loops are fused at a common parent loop, which is -// bound to the block dimension. Internally the inner loops have different -// extents but are bound to the same thread dimension. The smaller loop should -// be masked. -TEST(Cuda, MaskCompoundInnerLoop_CUDA) { - int OUTER_SIZE = 10; - int A_SIZE = 100; - int B_SIZE = 50; - BufHandle a_buf("a", {OUTER_SIZE, A_SIZE}, kFloat); - BufHandle b_buf("b", {OUTER_SIZE, B_SIZE}, kFloat); - BufHandle c_buf("c", {OUTER_SIZE, A_SIZE}, kFloat); - BufHandle d_buf("d", {OUTER_SIZE, B_SIZE}, kFloat); - - // Can't build this using Compute and transforms yet. - LoopOptions blockBound; - blockBound.set_gpu_block_index(0); - LoopOptions threadBound; - threadBound.set_gpu_thread_index(0); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - - StmtPtr stmt = For::make( - i, - 0, - OUTER_SIZE, - Block::make( - {For::make( - j, - 0, - A_SIZE, - c_buf.store({i, j}, ExprHandle(2) * a_buf.load(i, j)), - threadBound), - For::make( - k, - 0, - B_SIZE, - d_buf.store({i, k}, c_buf.load(i, k * 2) + b_buf.load(i, k)), - threadBound)}), - blockBound); - - stmt = FlattenIndexes(stmt); - stmt = IRSimplifier::simplify(stmt); - - CudaCodeGen cuda_cg(stmt, a_buf, b_buf, c_buf, d_buf); - - std::ostringstream oss; - oss << *cuda_cg.stmt(); - - // The write to D should be masked, but not the write to C. - const std::string& verification_pattern = - R"IR( -# CHECK-NOT: if ( -# CHECK: c[threadIdx.x + 100 * blockIdx.x] = -# CHECK: __syncthreads(); -# CHECK: if (threadIdx.x<50 -# CHECK: d[threadIdx.x + 50 * blockIdx.x] =)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - auto blockExtents = cuda_cg.gpu_block_extents(); - auto threadExtents = cuda_cg.gpu_thread_extents(); - ASSERT_TRUE(exprEquals(blockExtents[0], alloc(OUTER_SIZE))); - ASSERT_TRUE(exprEquals(threadExtents[0], alloc(A_SIZE))); - - PaddedBuffer a_v(OUTER_SIZE, A_SIZE); - PaddedBuffer b_v(OUTER_SIZE, B_SIZE); - PaddedBuffer c_v(OUTER_SIZE, A_SIZE); - PaddedBuffer d_v(OUTER_SIZE, B_SIZE); - - PaddedBuffer c_ref(OUTER_SIZE, A_SIZE); - PaddedBuffer d_ref(OUTER_SIZE, B_SIZE); - - for (const auto o : c10::irange(OUTER_SIZE)) { - for (const auto i : c10::irange(A_SIZE)) { - a_v(o, i) = (float)i; - c_ref(o, i) = (float)(i * 2); - } - for (const auto i : c10::irange(B_SIZE)) { - b_v(o, i) = (float)(B_SIZE - i); - d_ref(o, i) = c_ref(o, i * 2) + b_v(o, i); - } - } - - float* a_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&a_dev, OUTER_SIZE * A_SIZE * sizeof(float))); - float* b_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&b_dev, OUTER_SIZE * B_SIZE * sizeof(float))); - float* c_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&c_dev, OUTER_SIZE * A_SIZE * sizeof(float))); - float* d_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&d_dev, OUTER_SIZE * B_SIZE * sizeof(float))); - C10_CUDA_CHECK(cudaMemcpy( - a_dev, - a_v.data(), - OUTER_SIZE * A_SIZE * sizeof(float), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - b_dev, - b_v.data(), - OUTER_SIZE * B_SIZE * sizeof(float), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - c_dev, - c_v.data(), - OUTER_SIZE * A_SIZE * sizeof(float), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - d_dev, - d_v.data(), - OUTER_SIZE * B_SIZE * sizeof(float), - cudaMemcpyHostToDevice)); - 
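// FlattenIndexes above is what turns the 2-D stores written with
// c_buf.store({i, j}, ...) into the 1-D accesses the FileCheck pattern
// matches. For a buffer of shape {10, 100} the row-major rule is simply:
inline int64_t flatten2d(int64_t i, int64_t j, int64_t inner_extent) {
  // c[i, j] -> c[j + inner_extent * i], e.g. c[threadIdx.x + 100 * blockIdx.x]
  return j + inner_extent * i;
}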
C10_CUDA_CHECK(cudaDeviceSynchronize());
-
-  cuda_cg(a_dev, b_dev, c_dev, d_dev);
-
-  C10_CUDA_CHECK(cudaDeviceSynchronize());
-  C10_CUDA_CHECK(cudaMemcpy(
-      c_v.data(),
-      c_dev,
-      OUTER_SIZE * A_SIZE * sizeof(float),
-      cudaMemcpyDeviceToHost));
-  C10_CUDA_CHECK(cudaMemcpy(
-      d_v.data(),
-      d_dev,
-      OUTER_SIZE * B_SIZE * sizeof(float),
-      cudaMemcpyDeviceToHost));
-  C10_CUDA_CHECK(cudaDeviceSynchronize());
-
-  ExpectAllNear(c_v, c_ref, 1e-5);
-  ExpectAllNear(d_v, d_ref, 1e-5);
-
-  C10_CUDA_CHECK(cudaFree(a_dev));
-  C10_CUDA_CHECK(cudaFree(b_dev));
-  C10_CUDA_CHECK(cudaFree(c_dev));
-  C10_CUDA_CHECK(cudaFree(d_dev));
-}
-
-// Tests the case with two loops fused into a common parent, which is not bound
-// to any block or thread dimension; however, its two inner loops are bound to
-// the first thread dimension. This should work just like the MaskThreadDim
-// test, where the bigger loop is unmasked but the smaller one is masked.
-TEST(Cuda, MaskInnerLoopOneBlock_CUDA) {
-  int OUTER_SIZE = 10;
-  int A_SIZE = 100;
-  int B_SIZE = 50;
-  BufHandle a_buf("a", {OUTER_SIZE, A_SIZE}, kFloat);
-  BufHandle b_buf("b", {OUTER_SIZE, B_SIZE}, kFloat);
-  BufHandle c_buf("c", {OUTER_SIZE, A_SIZE}, kFloat);
-  BufHandle d_buf("d", {OUTER_SIZE, B_SIZE}, kFloat);
-
-  // Can't build this using Compute and transforms yet.
-  LoopOptions blockBound;
-  blockBound.set_gpu_block_index(0);
-  LoopOptions threadBound;
-  threadBound.set_gpu_thread_index(0);
-  VarHandle i("i", kInt);
-  VarHandle j("j", kInt);
-  VarHandle k("k", kInt);
-
-  StmtPtr stmt = For::make(
-      i,
-      0,
-      OUTER_SIZE,
-      Block::make(
-          {For::make(
-               j,
-               0,
-               A_SIZE,
-               c_buf.store({i, j}, ExprHandle(2) * a_buf.load(i, j)),
-               threadBound),
-           For::make(
-               k,
-               0,
-               B_SIZE,
-               d_buf.store({i, k}, c_buf.load(i, k * 2) + b_buf.load(i, k)),
-               threadBound)}));
-
-  stmt = FlattenIndexes(stmt);
-  stmt = IRSimplifier::simplify(stmt);
-
-  CudaCodeGen cuda_cg(stmt, a_buf, b_buf, c_buf, d_buf);
-
-  std::ostringstream oss;
-  oss << *cuda_cg.stmt();
-
-  // The outer loop remains a serial loop; the C write is unmasked, but the
-  // smaller D write is masked.
- const std::string& verification_pattern = - R"IR( -# CHECK: for (int i = 0; i < 10 -# CHECK-NOT: if ( -# CHECK: c[threadIdx.x + 100 * i] = -# CHECK: __syncthreads(); -# CHECK: if (threadIdx.x<50 -# CHECK: d[threadIdx.x + 50 * i] =)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - auto blockExtents = cuda_cg.gpu_block_extents(); - auto threadExtents = cuda_cg.gpu_thread_extents(); - ASSERT_TRUE(exprEquals(blockExtents[0], alloc(1))); - ASSERT_TRUE(exprEquals(threadExtents[0], alloc(A_SIZE))); - - PaddedBuffer a_v(OUTER_SIZE, A_SIZE); - PaddedBuffer b_v(OUTER_SIZE, B_SIZE); - PaddedBuffer c_v(OUTER_SIZE, A_SIZE); - PaddedBuffer d_v(OUTER_SIZE, B_SIZE); - - PaddedBuffer c_ref(OUTER_SIZE, A_SIZE); - PaddedBuffer d_ref(OUTER_SIZE, B_SIZE); - - for (const auto o : c10::irange(OUTER_SIZE)) { - for (const auto i : c10::irange(A_SIZE)) { - a_v(o, i) = (float)i; - c_ref(o, i) = (float)(i * 2); - } - for (const auto i : c10::irange(B_SIZE)) { - b_v(o, i) = (float)(B_SIZE - i); - d_ref(o, i) = c_ref(o, i * 2) + b_v(o, i); - } - } - - float* a_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&a_dev, OUTER_SIZE * A_SIZE * sizeof(float))); - float* b_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&b_dev, OUTER_SIZE * B_SIZE * sizeof(float))); - float* c_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&c_dev, OUTER_SIZE * A_SIZE * sizeof(float))); - float* d_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&d_dev, OUTER_SIZE * B_SIZE * sizeof(float))); - C10_CUDA_CHECK(cudaMemcpy( - a_dev, - a_v.data(), - OUTER_SIZE * A_SIZE * sizeof(float), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - b_dev, - b_v.data(), - OUTER_SIZE * B_SIZE * sizeof(float), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - c_dev, - c_v.data(), - OUTER_SIZE * A_SIZE * sizeof(float), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - d_dev, - d_v.data(), - OUTER_SIZE * B_SIZE * sizeof(float), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - cuda_cg(a_dev, b_dev, c_dev, d_dev); - - C10_CUDA_CHECK(cudaDeviceSynchronize()); - C10_CUDA_CHECK(cudaMemcpy( - c_v.data(), - c_dev, - OUTER_SIZE * A_SIZE * sizeof(float), - cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaMemcpy( - d_v.data(), - d_dev, - OUTER_SIZE * B_SIZE * sizeof(float), - cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - ExpectAllNear(c_v, c_ref, 1e-5); - ExpectAllNear(d_v, d_ref, 1e-5); - - C10_CUDA_CHECK(cudaFree(a_dev)); - C10_CUDA_CHECK(cudaFree(b_dev)); - C10_CUDA_CHECK(cudaFree(c_dev)); - C10_CUDA_CHECK(cudaFree(d_dev)); -} - -// Tests the case with two loop nests, each of which bound to the same block -// size, but with internal loops bound to different thread rank (ie x and y). In -// this case both bodies must be masked against the other dimension being > 0. -// Note: this is a bit degenerate no one would actually write this for perf. 
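// A CUDA sketch of what MaskMultiDimMultiAxis below checks for (names are
// illustrative): with the inner loops bound to threadIdx.x and threadIdx.y
// respectively, each store must be masked against the *other* axis being 0.
__global__ void maskMultiAxisSketch(
    const float* a, const float* b, float* C, float* D) {
  if (threadIdx.y < 1) {
    C[threadIdx.x + 30 * blockIdx.x] = 2.f * a[threadIdx.x + 30 * blockIdx.x];
  }
  __syncthreads();
  if (threadIdx.x < 1) {
    D[threadIdx.y + 15 * blockIdx.x] =
        C[2 * threadIdx.y + 30 * blockIdx.x] + b[threadIdx.y + 15 * blockIdx.x];
  }
}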
-TEST(Cuda, MaskMultiDimMultiAxis_CUDA) { - int OUTER_SIZE = 10; - int A_SIZE = 30; - int B_SIZE = 15; - BufHandle a_buf("a", {OUTER_SIZE, A_SIZE}, kFloat); - BufHandle b_buf("b", {OUTER_SIZE, B_SIZE}, kFloat); - Tensor c = Compute( - "C", {OUTER_SIZE, A_SIZE}, [&](const VarHandle& i, const VarHandle& j) { - return ExprHandle(2) * a_buf.load(i, j); - }); - Tensor d = Compute( - "D", {OUTER_SIZE, B_SIZE}, [&](const VarHandle& i, const VarHandle& j) { - return c.load(i, j * 2) + b_buf.load(i, j); - }); - - LoopNest l({c, d}); - std::vector loops = l.getLoopStmtsFor(c); - loops[0]->set_gpu_block_index(0); - loops[1]->set_gpu_thread_index(0); - loops = l.getLoopStmtsFor(d); - loops[0]->set_gpu_block_index(0); - loops[1]->set_gpu_thread_index(1); - - l.prepareForCodegen(); - StmtPtr stmt = l.root_stmt(); - CudaCodeGen cuda_cg(stmt, c, d, a_buf, b_buf); - - std::ostringstream oss; - oss << *cuda_cg.stmt(); - - // Both stores masked against the other thread dim < 1. - const std::string& verification_pattern = - R"IR( -# CHECK: if (threadIdx.y<1 -# CHECK: C[threadIdx.x + 30 * blockIdx.x] = -# CHECK: __syncthreads(); -# CHECK: if (threadIdx.x<1 -# CHECK: D[threadIdx.y + 15 * blockIdx.x] =)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - auto blockExtents = cuda_cg.gpu_block_extents(); - auto threadExtents = cuda_cg.gpu_thread_extents(); - ASSERT_TRUE(exprEquals(blockExtents[0], alloc(OUTER_SIZE))); - ASSERT_TRUE(exprEquals(threadExtents[0], alloc(A_SIZE))); - - PaddedBuffer a_v(OUTER_SIZE, A_SIZE); - PaddedBuffer b_v(OUTER_SIZE, B_SIZE); - PaddedBuffer c_v(OUTER_SIZE, A_SIZE); - PaddedBuffer d_v(OUTER_SIZE, B_SIZE); - - PaddedBuffer c_ref(OUTER_SIZE, A_SIZE); - PaddedBuffer d_ref(OUTER_SIZE, B_SIZE); - - for (const auto o : c10::irange(OUTER_SIZE)) { - for (const auto i : c10::irange(A_SIZE)) { - a_v(o, i) = (float)i; - c_ref(o, i) = (float)(i * 2); - } - } - - for (const auto o : c10::irange(OUTER_SIZE)) { - for (const auto i : c10::irange(B_SIZE)) { - b_v(o, i) = (float)(B_SIZE - i); - d_ref(o, i) = c_ref(o, i * 2) + b_v(o, i); - } - } - - float* a_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&a_dev, OUTER_SIZE * A_SIZE * sizeof(float))); - float* b_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&b_dev, OUTER_SIZE * B_SIZE * sizeof(float))); - float* c_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&c_dev, OUTER_SIZE * A_SIZE * sizeof(float))); - float* d_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&d_dev, OUTER_SIZE * B_SIZE * sizeof(float))); - C10_CUDA_CHECK(cudaMemcpy( - a_dev, - a_v.data(), - OUTER_SIZE * A_SIZE * sizeof(float), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - b_dev, - b_v.data(), - OUTER_SIZE * B_SIZE * sizeof(float), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - c_dev, - c_v.data(), - OUTER_SIZE * A_SIZE * sizeof(float), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - d_dev, - d_v.data(), - OUTER_SIZE * B_SIZE * sizeof(float), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - cuda_cg(c_dev, d_dev, a_dev, b_dev); - - C10_CUDA_CHECK(cudaDeviceSynchronize()); - C10_CUDA_CHECK(cudaMemcpy( - c_v.data(), - c_dev, - OUTER_SIZE * A_SIZE * sizeof(float), - cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaMemcpy( - d_v.data(), - d_dev, - OUTER_SIZE * B_SIZE * sizeof(float), - cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - ExpectAllNear(c_v, c_ref, 1e-5); - ExpectAllNear(d_v, d_ref, 1e-5); - - C10_CUDA_CHECK(cudaFree(a_dev)); - C10_CUDA_CHECK(cudaFree(b_dev)); - 
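// The __syncthreads() these masking tests check for is not decorative: D is
// computed from C entries written by *different* threads (c.load(i, j * 2)),
// so the barrier must order the two stores within a block. A minimal CUDA
// illustration of the hazard being prevented (hypothetical kernel):
__global__ void barrierBetweenStores(const float* b, float* C, float* D) {
  C[threadIdx.x] = 2.f * threadIdx.x;
  __syncthreads();  // thread t below reads the C entry written by thread 2t
  if (threadIdx.x < 15) {
    D[threadIdx.x] = C[threadIdx.x * 2] + b[threadIdx.x];
  }
}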
C10_CUDA_CHECK(cudaFree(c_dev)); - C10_CUDA_CHECK(cudaFree(d_dev)); -} - -// Tests the case with two loop nests, each bound to both Block and Thread but -// the second loop is smaller in both cases - the second store must be masked -// for both the block and thread dimension. -TEST(Cuda, MaskMultiDimMultiLevel_CUDA) { - int OUTER_A_SIZE = 10; - int OUTER_B_SIZE = 5; - int A_SIZE = 30; - int B_SIZE = 15; - BufHandle a_buf("a", {OUTER_A_SIZE, A_SIZE}, kFloat); - BufHandle b_buf("b", {OUTER_B_SIZE, B_SIZE}, kFloat); - Tensor c = Compute( - "C", {OUTER_A_SIZE, A_SIZE}, [&](const VarHandle& i, const VarHandle& j) { - return ExprHandle(2) * a_buf.load(i, j); - }); - Tensor d = Compute( - "D", {OUTER_B_SIZE, B_SIZE}, [&](const VarHandle& i, const VarHandle& j) { - return c.load(i, j * 2) + b_buf.load(i, j); - }); - - LoopNest l({c, d}); - std::vector loops = l.getLoopStmtsFor(c); - loops[0]->set_gpu_block_index(0); - loops[1]->set_gpu_thread_index(0); - loops = l.getLoopStmtsFor(d); - loops[0]->set_gpu_block_index(0); - loops[1]->set_gpu_thread_index(0); - - l.prepareForCodegen(); - StmtPtr stmt = l.root_stmt(); - CudaCodeGen cuda_cg(stmt, c, d, a_buf, b_buf); - - std::ostringstream oss; - oss << *cuda_cg.stmt(); - - // The write to D should be masked twice, but not the write to C. - const std::string& verification_pattern = - R"IR( -# CHECK-NOT: if ( -# CHECK: C[threadIdx.x + 30 * blockIdx.x] = -# CHECK: __syncthreads(); -# CHECK: if (blockIdx.x<5 -# CHECK: if (threadIdx.x<15 -# CHECK: D[threadIdx.x + 15 * blockIdx.x] =)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - auto blockExtents = cuda_cg.gpu_block_extents(); - auto threadExtents = cuda_cg.gpu_thread_extents(); - ASSERT_TRUE(exprEquals(blockExtents[0], alloc(OUTER_A_SIZE))); - ASSERT_TRUE(exprEquals(threadExtents[0], alloc(A_SIZE))); - - PaddedBuffer a_v(OUTER_A_SIZE, A_SIZE); - PaddedBuffer b_v(OUTER_B_SIZE, B_SIZE); - PaddedBuffer c_v(OUTER_A_SIZE, A_SIZE); - PaddedBuffer d_v(OUTER_B_SIZE, B_SIZE); - - PaddedBuffer c_ref(OUTER_A_SIZE, A_SIZE); - PaddedBuffer d_ref(OUTER_B_SIZE, B_SIZE); - - for (const auto o : c10::irange(OUTER_A_SIZE)) { - for (const auto i : c10::irange(A_SIZE)) { - a_v(o, i) = (float)i; - c_ref(o, i) = (float)(i * 2); - } - } - - for (const auto o : c10::irange(OUTER_B_SIZE)) { - for (const auto i : c10::irange(B_SIZE)) { - b_v(o, i) = (float)(B_SIZE - i); - d_ref(o, i) = c_ref(o, i * 2) + b_v(o, i); - } - } - - float* a_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&a_dev, OUTER_A_SIZE * A_SIZE * sizeof(float))); - float* b_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&b_dev, OUTER_B_SIZE * B_SIZE * sizeof(float))); - float* c_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&c_dev, OUTER_A_SIZE * A_SIZE * sizeof(float))); - float* d_dev = nullptr; - C10_CUDA_CHECK(cudaMalloc(&d_dev, OUTER_B_SIZE * B_SIZE * sizeof(float))); - C10_CUDA_CHECK(cudaMemcpy( - a_dev, - a_v.data(), - OUTER_A_SIZE * A_SIZE * sizeof(float), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - b_dev, - b_v.data(), - OUTER_B_SIZE * B_SIZE * sizeof(float), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - c_dev, - c_v.data(), - OUTER_A_SIZE * A_SIZE * sizeof(float), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaMemcpy( - d_dev, - d_v.data(), - OUTER_B_SIZE * B_SIZE * sizeof(float), - cudaMemcpyHostToDevice)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - cuda_cg(c_dev, d_dev, a_dev, b_dev); - - C10_CUDA_CHECK(cudaDeviceSynchronize()); - C10_CUDA_CHECK(cudaMemcpy( - c_v.data(), - c_dev, - OUTER_A_SIZE * 
A_SIZE * sizeof(float), - cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaMemcpy( - d_v.data(), - d_dev, - OUTER_B_SIZE * B_SIZE * sizeof(float), - cudaMemcpyDeviceToHost)); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - ExpectAllNear(c_v, c_ref, 1e-5); - ExpectAllNear(d_v, d_ref, 1e-5); - - C10_CUDA_CHECK(cudaFree(a_dev)); - C10_CUDA_CHECK(cudaFree(b_dev)); - C10_CUDA_CHECK(cudaFree(c_dev)); - C10_CUDA_CHECK(cudaFree(d_dev)); -} - -} // namespace jit -} // namespace torch - -#endif diff --git a/test/cpp/tensorexpr/test_dynamic_shapes.cpp b/test/cpp/tensorexpr/test_dynamic_shapes.cpp deleted file mode 100644 index 07b9872fb832..000000000000 --- a/test/cpp/tensorexpr/test_dynamic_shapes.cpp +++ /dev/null @@ -1,701 +0,0 @@ -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace torch { -namespace jit { - -using namespace torch::indexing; -using namespace torch::jit::tensorexpr; - -TEST(DynamicShapes, SimpleGraph) { -#ifdef TORCH_ENABLE_LLVM - std::shared_ptr graph = std::make_shared(); - const auto graph_string = R"IR( - graph(%x : Tensor, - %SS_2 : int, - %SS_3 : int): - %3 : Tensor = aten::tanh(%x) - %4 : Tensor = aten::erf(%3) - return (%4))IR"; - torch::jit::parseIR(graph_string, graph.get()); - - auto x_inp = graph->inputs()[0]; - auto x_type = TensorType::create(at::rand({10, 5})); - std::vector x_sym_dims( - {c10::ShapeSymbol::newSymbol(), c10::ShapeSymbol::newSymbol()}); - auto x_sym_type = x_type->withSymbolicShapes(x_sym_dims); - graph->inputs().at(0)->setType(x_sym_type); - for (const auto n : graph->nodes()) { - n->output()->setType(x_sym_type); - } - - // Graph with symbolic shapes: - // - // graph(%x : Float(SS(-2), SS(-3)), - // %SS_2 : int, - // %SS_3 : int): - // %3 : Float(SS(-2), SS(-3)) = aten::tanh(%x) - // %4 : Float(SS(-2), SS(-3)) = aten::erf(%3) - // return (%4) - - std::vector input_desc = { - torch::jit::StrideInput::TENSOR_CONT}; - std::unordered_map< - const torch::jit::Value*, - std::vector> - symbolic_strides; - symbolic_strides[x_inp] = input_desc; - symbolic_strides[graph->outputs().at(0)] = input_desc; - std::vector symbolic_shape_inputs = c10::fmap( - x_sym_dims, - [](const c10::ShapeSymbol& shapeSym) { return shapeSym.value(); }); - - TensorExprKernel kernel( - graph, {}, symbolic_shape_inputs, false, symbolic_strides); - // Run with the same static dims as the one we initialized the graph with. - { - auto a = at::rand({10, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto ref = at::erf(at::tanh(a)); - - std::vector stack = fmap(std::vector({a})); - stack.push_back(10); - stack.push_back(5); - kernel.run(stack); - - auto o = stack[0].toTensor(); - ASSERT_TRUE(at::allclose(o, ref)); - } - - // Run with inputs having different dims. - { - auto a = at::rand({50, 100}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto ref = at::erf(at::tanh(a)); - - std::vector stack = fmap(std::vector({a})); - stack.push_back(50); - stack.push_back(100); - kernel.run(stack); - - auto o = stack[0].toTensor(); - ASSERT_TRUE(at::allclose(o, ref)); - } -#endif -} - -TEST(DynamicShapes, GraphWith2InputsSameDims) { -#ifdef TORCH_ENABLE_LLVM - // The two inputs in this graph must have the same dims. 
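This test and the DynamicShapes tests that follow share one setup pattern, sketched here once so the individual tests read more easily. The sketch assumes the removed TensorExprKernel API; graph, inputs, and symbolic_strides stand for the values each test builds:

    // 1. Give the graph inputs symbolic dims; each new ShapeSymbol gets a
    //    negative id that prints as SS(-2), SS(-3), ... in the IR dump.
    auto sym = c10::ShapeSymbol::newSymbol();

    // 2. Hand the symbol ids to the kernel as symbolic_shape_inputs.
    TensorExprKernel kernel(
        graph, {}, /*symbolic_shape_inputs=*/{sym.value()}, false, symbolic_strides);

    // 3. At call time, push the tensors first, then one integer per
    //    symbolic dim, in the order of symbolic_shape_inputs.
    std::vector<c10::IValue> stack = c10::fmap<c10::IValue>(inputs);
    stack.emplace_back(100);  // concrete size of the symbolic dim this run
    kernel.run(stack);        // stack[0] now holds the output tensor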
- std::shared_ptr graph = std::make_shared(); - const auto graph_string = R"IR( - graph(%x : Tensor, - %y : Tensor, - %SS_2 : int, - %SS_3 : int): - %3 : Tensor = aten::tanh(%x) - %4 : Tensor = aten::erf(%3) - %5 : Tensor = aten::mul(%4, %y) - return (%5))IR"; - torch::jit::parseIR(graph_string, graph.get()); - - auto x_inp = graph->inputs()[0]; - auto y_inp = graph->inputs()[1]; - auto x_type = TensorType::create(at::rand({10, 5})); - std::vector x_sym_dims( - {c10::ShapeSymbol::newSymbol(), c10::ShapeSymbol::newSymbol()}); - auto x_sym_type = x_type->withSymbolicShapes(x_sym_dims); - graph->inputs().at(0)->setType(x_sym_type); - graph->inputs().at(1)->setType(x_sym_type); - for (const auto n : graph->nodes()) { - n->output()->setType(x_sym_type); - } - - // Graph with symbolic shapes: - // - // graph(%x : Float(SS(-4), SS(-5)), - // %y : Float(SS(-4), SS(-5)), - // %SS_2 : int, - // %SS_3 : int): - // %4 : Float(SS(-4), SS(-5)) = aten::tanh(%x) - // %5 : Float(SS(-4), SS(-5)) = aten::erf(%4) - // %6 : Float(SS(-4), SS(-5)) = aten::mul(%5, %y) - // return (%6) - - std::vector symbolic_shape_inputs = c10::fmap( - x_sym_dims, - [](const c10::ShapeSymbol& shapeSym) { return shapeSym.value(); }); - - std::vector input_desc = { - torch::jit::StrideInput::TENSOR_CONT}; - std::unordered_map< - const torch::jit::Value*, - std::vector> - symbolic_strides; - symbolic_strides[x_inp] = input_desc; - symbolic_strides[y_inp] = input_desc; - symbolic_strides[graph->outputs().at(0)] = input_desc; - - TensorExprKernel kernel( - graph, {}, symbolic_shape_inputs, false, symbolic_strides); - - // Run with the same static dims as the one we initialized the graph with. - { - auto a = at::rand({10, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto b = at::rand({10, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto ref = at::mul(at::erf(at::tanh(a)), b); - - std::vector stack = fmap(std::vector({a, b})); - stack.push_back(10); - stack.push_back(5); - kernel.run(stack); - - auto o = stack[0].toTensor(); - ASSERT_TRUE(at::allclose(o, ref)); - } - - // Run with inputs having different dims. - { - auto a = at::rand({50, 100}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto b = at::rand({50, 100}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto ref = at::mul(at::erf(at::tanh(a)), b); - - std::vector stack = fmap(std::vector({a, b})); - stack.push_back(50); - stack.push_back(100); - kernel.run(stack); - - auto o = stack[0].toTensor(); - ASSERT_TRUE(at::allclose(o, ref)); - } -#endif -} - -TEST(DynamicShapes, GraphWith2InputsAndBroadcast) { -#ifdef TORCH_ENABLE_LLVM - // The second input to the graph has a dim of size 1 which should be - // broadcasted in the at::mul op. 
- std::shared_ptr graph = std::make_shared(); - const auto graph_string = R"IR( - graph(%x : Float(10, 5, requires_grad=0, device=cpu), - %y : Float(1, 5, requires_grad=0, device=cpu), - %SS_2 : int, - %SS_3 : int): - %3 : Tensor = aten::tanh(%x) - %4 : Tensor = aten::erf(%3) - %5 : Tensor = aten::mul(%4, %y) - return (%5))IR"; - torch::jit::parseIR(graph_string, graph.get()); - - auto x_inp = graph->inputs()[0]; - auto y_inp = graph->inputs()[1]; - auto x_type = TensorType::create(at::rand({10, 5})); - auto y_type = TensorType::create(at::rand({1, 5})); - auto x_dim0_sym = c10::ShapeSymbol::newSymbol(); - auto x_dim1_sym = c10::ShapeSymbol::newSymbol(); - auto x_sym_type = x_type->withSymbolicShapes( - std::vector({x_dim0_sym, x_dim1_sym})); - auto y_sym_type = y_type->withSymbolicShapes(std::vector( - {c10::ShapeSymbol::fromStaticSize(1), x_dim1_sym})); - graph->inputs().at(0)->setType(x_sym_type); - graph->inputs().at(1)->setType(y_sym_type); - for (const auto n : graph->nodes()) { - n->output()->setType(x_sym_type); - } - - // Graph with symbolic shapes: - // - // graph(%x : Float(SS(-6), SS(-7)), - // %y : Float(1, SS(-7)), - // %SS_2 : int, - // %SS_3 : int): - // %4 : Float(SS(-6), SS(-7)) = aten::tanh(%x) - // %5 : Float(SS(-6), SS(-7)) = aten::erf(%4) - // %6 : Float(SS(-6), SS(-7)) = aten::mul(%5, %y) - // return (%6) - - std::vector symbolic_shape_inputs( - {x_dim0_sym.value(), x_dim1_sym.value()}); - - std::vector input_desc = { - torch::jit::StrideInput::TENSOR_CONT}; - std::unordered_map< - const torch::jit::Value*, - std::vector> - symbolic_strides; - symbolic_strides[x_inp] = input_desc; - symbolic_strides[y_inp] = input_desc; - symbolic_strides[graph->outputs().at(0)] = input_desc; - - TensorExprKernel kernel( - graph, {}, symbolic_shape_inputs, false, symbolic_strides); - - // Run with the same static dims as the one we initialized the graph with. - { - auto a = at::rand({10, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto b = at::rand({1, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto ref = at::mul(at::erf(at::tanh(a)), b); - - std::vector stack = fmap(std::vector({a, b})); - stack.push_back(10); - stack.push_back(5); - kernel.run(stack); - - auto o = stack[0].toTensor(); - ASSERT_TRUE(at::allclose(o, ref)); - } - - // Run with inputs having different dims. - { - auto a = at::rand({50, 100}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto b = at::rand({1, 100}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto ref = at::mul(at::erf(at::tanh(a)), b); - - std::vector stack = fmap(std::vector({a, b})); - stack.push_back(50); - stack.push_back(100); - kernel.run(stack); - - auto o = stack[0].toTensor(); - ASSERT_TRUE(at::allclose(o, ref)); - } -#endif -} - -TEST(DynamicShapes, GraphWithPartiallySymbolicOutput) { -#ifdef TORCH_ENABLE_LLVM - // The second input to the graph has a dim of size 1 which should be - // broadcasted in the at::mul op. 
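Both the broadcast test above and the partially symbolic test below hinge on mixing a pinned dimension with a shared symbol when building an input type. The core construction, with the std::vector<c10::ShapeSymbol> template argument written out explicitly:

    // y : Float(1, SS(-7)); dim 0 is statically 1, so aten::mul broadcasts
    // it, while dim 1 shares x's second shape symbol.
    auto y_sym_type = y_type->withSymbolicShapes(std::vector<c10::ShapeSymbol>(
        {c10::ShapeSymbol::fromStaticSize(1), x_dim1_sym}));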
- std::shared_ptr graph = std::make_shared(); - const auto graph_string = R"IR( - graph(%x : Float(1, 5, requires_grad=0, device=cpu), - %y : Float(1, 5, requires_grad=0, device=cpu), - %SS_2 : int): - %4 : Tensor = aten::tanh(%x) - %5 : Tensor = aten::mul(%4, %y) - return (%5))IR"; - torch::jit::parseIR(graph_string, graph.get()); - - auto x_inp = graph->inputs()[0]; - auto y_inp = graph->inputs()[1]; - auto x_type = TensorType::create(at::rand({1, 5})); - auto x_dim1_sym = c10::ShapeSymbol::newSymbol(); - auto x_sym_type = x_type->withSymbolicShapes(std::vector( - {c10::ShapeSymbol::fromStaticSize(1), x_dim1_sym})); - graph->inputs().at(0)->setType(x_sym_type); - graph->inputs().at(1)->setType(x_sym_type); - for (const auto n : graph->nodes()) { - n->output()->setType(x_sym_type); - } - - // Graph with symbolic shapes: - // - // graph(%x : Float(1, SS(-2)), - // %y : Float(1, SS(-2)), - // %SS_2 : int): - // %3 : Float(1, SS(-2)) = aten::tanh(%x) - // %4 : Float(1, SS(-2)) = aten::mul(%3, %y) - // return (%4) - - std::vector symbolic_shape_inputs({x_dim1_sym.value()}); - - std::vector input_desc = { - torch::jit::StrideInput::TENSOR_CONT}; - std::unordered_map< - const torch::jit::Value*, - std::vector> - symbolic_strides; - symbolic_strides[x_inp] = input_desc; - symbolic_strides[y_inp] = input_desc; - symbolic_strides[graph->outputs().at(0)] = input_desc; - - TensorExprKernel kernel( - graph, {}, symbolic_shape_inputs, false, symbolic_strides); - - // Run with the same static dims as the one we initialized the graph with. - { - auto a = at::rand({1, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto b = at::rand({1, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto ref = at::mul(at::tanh(a), b); - - std::vector stack = fmap(std::vector({a, b})); - stack.push_back(5); - kernel.run(stack); - - auto o = stack[0].toTensor(); - ASSERT_TRUE(at::allclose(o, ref)); - } - - // Run with inputs having different dims. 
- { - auto a = at::rand({1, 100}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto b = at::rand({1, 100}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto ref = at::mul(at::tanh(a), b); - - std::vector stack = fmap(std::vector({a, b})); - stack.push_back(100); - kernel.run(stack); - - auto o = stack[0].toTensor(); - ASSERT_TRUE(at::allclose(o, ref)); - } -#endif -} - -TEST(DynamicShapes, GraphWithSymbolicStrides) { -#ifdef TORCH_ENABLE_LLVM - std::shared_ptr graph = std::make_shared(); - const auto graph_string = R"IR( - graph(%0 : Float(SS(-2), SS(-3), requires_grad=0, device=cpu), - %1 : Float(SS(-2), SS(-3), requires_grad=0, device=cpu), - %SS_3 : int, - %SS_2 : int): - %15 : int = prim::Constant[value=1]() - %21 : Float(SS(-2), SS(-3), requires_grad=0, device=cpu) = aten::add(%0, %1, %15) - %22 : Float(SS(-2), SS(-3), requires_grad=0, device=cpu) = aten::mul(%21, %0) - return (%22))IR"; - parseIR(graph_string, &*graph); - - std::vector input_desc = { - torch::jit::StrideInput::S_AS_ARG, torch::jit::StrideInput::S_ONE}; - std::vector output_desc = { - torch::jit::StrideInput::TENSOR_CONT}; - std::unordered_map< - const torch::jit::Value*, - std::vector> - symbolic_strides; - symbolic_strides[graph->inputs().at(0)] = input_desc; - symbolic_strides[graph->inputs().at(1)] = input_desc; - symbolic_strides[graph->outputs().at(0)] = output_desc; - std::vector symbolic_shape_inputs = {-3, -2}; - TensorExprKernel k(graph, {}, symbolic_shape_inputs, false, symbolic_strides); - - { - auto x0 = at::rand({10, 32}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto x1 = at::rand({10, 32}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto ref = at::mul(at::add(x0, x1, 1), x0); - - std::vector inputs = {x0, x1}; - std::vector stack = at::fmap(inputs); - stack.push_back(32); - stack.push_back(10); - k.run(stack); - - auto o = stack[0].toTensor(); - ASSERT_TRUE(at::allclose(o, ref)); - } - - { - auto x0 = at::rand({10, 32}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto x1 = at::rand({10, 32}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto out = - at::rand({10, 32}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto ref = at::mul(at::add(x0, x1, 1), x0); - - std::vector inputs = {out, x0, x1}; - std::vector stack = at::fmap(inputs); - stack.push_back(32); - stack.push_back(10); - k.runWithAllocatedOutputs(stack); - - ASSERT_TRUE(at::allclose(out, ref)); - } -#endif -} - -TEST(DynamicShapes, GraphWithCatAndBroadcast) { -#ifdef TORCH_ENABLE_LLVM - std::shared_ptr graph = std::make_shared(); - const auto graph_string = R"IR( - graph(%x : Float(10, 5, requires_grad=0, device=cpu), - %y : Float(4, 5, requires_grad=0, device=cpu), - %z : Float(1, 1, requires_grad=0, device=cpu), - %SS_2 : int, - %SS_3 : int, - %SS_4 : int, - %SS_5 : int): - %11 : int = prim::Constant[value=0]() - %3 : Tensor = aten::tanh(%x) - %out1 : Tensor = aten::erf(%3) - %out2 : Tensor = aten::relu(%y) - %10 : Tensor[] = prim::ListConstruct(%out1, %out2) - %25 : Tensor = aten::cat(%10, %11) - %28 : Tensor = aten::hardswish(%25) - %29 : Tensor = aten::mul(%28, %z) - return (%29))IR"; - torch::jit::parseIR(graph_string, graph.get()); - - auto x_inp = graph->inputs()[0]; - auto y_inp = graph->inputs()[1]; - auto z_inp = graph->inputs()[2]; - auto x_type = TensorType::create(at::rand({10, 5})); - auto y_type = TensorType::create(at::rand({4, 5})); - auto z_type = TensorType::create(at::rand({1, 1})); - auto x_dim0_sym = c10::ShapeSymbol::newSymbol(); - auto x_dim1_sym = c10::ShapeSymbol::newSymbol(); 
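GraphWithSymbolicStrides above also exercises the stride descriptors (S_AS_ARG passes a stride in as a runtime argument, S_ONE pins it to 1, TENSOR_CONT asserts a contiguous tensor) and both kernel entry points. A sketch of the two calling conventions, under the same assumptions as the earlier sketches; out, x0, and x1 are illustrative names:

    // run(): the kernel allocates its outputs; the stack holds the input
    // tensors followed by one integer per symbolic shape input.
    std::vector<c10::IValue> stack = at::fmap<c10::IValue>(
        std::vector<at::Tensor>({x0, x1}));
    stack.emplace_back(32);  // SS(-3)
    stack.emplace_back(10);  // SS(-2)
    k.run(stack);            // stack[0] is the newly allocated output

    // runWithAllocatedOutputs(): caller-owned outputs come first on the
    // stack and are written in place; the kernel allocates nothing.
    std::vector<c10::IValue> stack2 = at::fmap<c10::IValue>(
        std::vector<at::Tensor>({out, x0, x1}));
    stack2.emplace_back(32);
    stack2.emplace_back(10);
    k.runWithAllocatedOutputs(stack2);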
- auto x_sym_type = x_type->withSymbolicShapes( - std::vector({x_dim0_sym, x_dim1_sym})); - auto y_dim0_sym = c10::ShapeSymbol::newSymbol(); - auto y_sym_type = y_type->withSymbolicShapes( - std::vector({y_dim0_sym, x_dim1_sym})); - graph->inputs().at(0)->setType(x_sym_type); - graph->inputs().at(1)->setType(y_sym_type); - auto cat_dim0_sym = c10::ShapeSymbol::newSymbol(); - auto cat_out_type = x_type->withSymbolicShapes( - std::vector({cat_dim0_sym, x_dim1_sym})); - auto nodeIt = graph->nodes().begin(); - ++nodeIt; - nodeIt->output()->setType(x_sym_type); // aten::tanh - ++nodeIt; - nodeIt->output()->setType(x_sym_type); // aten::erf - ++nodeIt; - nodeIt->output()->setType(y_sym_type); // aten::relu - ++nodeIt; - ++nodeIt; - nodeIt->output()->setType(cat_out_type); // aten::cat - ++nodeIt; - nodeIt->output()->setType(cat_out_type); // aten::hardswish - ++nodeIt; - nodeIt->output()->setType(cat_out_type); // aten::mul - - // Graph with symbolic shapes: - // - // graph(%x : Float(SS(-2), SS(-3)), - // %y : Float(SS(-4), SS(-3)), - // %z : Float(1, 1), - // %SS_2 : int, - // %SS_3 : int, - // %SS_4 : int, - // %SS_5 : int): - // %7 : int = prim::Constant[value=0]() - // %8 : Float(SS(-2), SS(-3)) = aten::tanh(%x) - // %9 : Float(SS(-2), SS(-3)) = aten::erf(%8) - // %10 : Float(SS(-4), SS(-3)) = aten::relu(%y) - // %11 : Tensor[] = prim::ListConstruct(%9, %10) - // %12 : Float(SS(-5), SS(-3)) = aten::cat(%11, %7) - // %13 : Float(SS(-5), SS(-3)) = aten::hardswish(%12) - // %14 : Float(SS(-5), SS(-3)) = aten::mul(%13, %z) - // return (%14) - - std::vector symbolic_shape_inputs( - {x_dim0_sym.value(), - x_dim1_sym.value(), - y_dim0_sym.value(), - cat_dim0_sym.value()}); - - std::vector input_desc = { - torch::jit::StrideInput::TENSOR_CONT}; - std::unordered_map< - const torch::jit::Value*, - std::vector> - symbolic_strides; - symbolic_strides[x_inp] = input_desc; - symbolic_strides[y_inp] = input_desc; - symbolic_strides[z_inp] = input_desc; - symbolic_strides[graph->outputs().at(0)] = input_desc; - - TensorExprKernel kernel( - graph, {}, symbolic_shape_inputs, false, symbolic_strides); - - auto a = at::rand({10, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto b = at::rand({4, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto c = at::rand({1, 1}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto ref = at::mul( - at::hardswish(at::cat({at::erf(at::tanh(a)), at::relu(b)}, 0)), c); - - std::vector stack = fmap(std::vector({a, b, c})); - stack.push_back(10); - stack.push_back(5); - stack.push_back(4); - stack.push_back(14); - kernel.run(stack); - - auto o = stack[0].toTensor(); - ASSERT_TRUE(at::allclose(o, ref)); -#endif -} - -TEST(DynamicShapes, GraphFromModel) { -#ifdef TORCH_ENABLE_LLVM - std::shared_ptr graph = std::make_shared(); - const auto graph_string = R"IR( - graph(%0 : Float(SS(-2), SS(-3), requires_grad=0, device=cpu), - %1 : Float(SS(-2), SS(-4), requires_grad=0, device=cpu), - %2 : Float(SS(-2), SS(-5), requires_grad=0, device=cpu), - %input.4 : Long(SS(-2), SS(-6), requires_grad=0, device=cpu), - %4 : Float(SS(-7), requires_grad=0, device=cpu), - %5 : Float(SS(-7), requires_grad=0, device=cpu), - %SS_10 : int, - %SS_9 : int, - %SS_8 : int, - %SS_7 : int, - %SS_6 : int, - %SS_5 : int, - %SS_4 : int, - %SS_3 : int, - %SS_2 : int): - %15 : int = prim::Constant[value=1]() - %16 : bool = prim::Constant[value=0]() - %17 : int = prim::Constant[value=6]() - %18 : Float(SS(-2), SS(-6), strides=[139, 1], requires_grad=0, device=cpu) = aten::to(%input.4, %17, %16, 
%16) - %19 : Tensor[] = prim::ListConstruct(%0, %1, %18, %2) - %20 : Float(SS(-2), SS(-8), strides=[261, 1], requires_grad=0, device=cpu) = aten::cat(%19, %15) - %21 : Float(SS(-2), SS(-9), strides=[261, 1], requires_grad=0, device=cpu) = aten::add(%20, %5, %15) - %22 : Float(SS(-2), SS(-10), requires_grad=0, device=cpu) = aten::mul(%21, %4) - return (%22))IR"; - parseIR(graph_string, &*graph); - - std::vector input_desc = { - torch::jit::StrideInput::TENSOR_CONT}; - std::unordered_map< - const torch::jit::Value*, - std::vector> - symbolic_strides; - symbolic_strides[graph->inputs().at(0)] = input_desc; - symbolic_strides[graph->inputs().at(1)] = input_desc; - symbolic_strides[graph->inputs().at(2)] = input_desc; - symbolic_strides[graph->inputs().at(3)] = input_desc; - symbolic_strides[graph->inputs().at(4)] = input_desc; - symbolic_strides[graph->inputs().at(5)] = input_desc; - symbolic_strides[graph->outputs().at(0)] = input_desc; - std::vector symbolic_shape_inputs = { - -10, -9, -8, -7, -6, -5, -4, -3, -2}; - TensorExprKernel k(graph, {}, symbolic_shape_inputs, false, symbolic_strides); - - int64_t i2 = 10; - int64_t i3 = 32; - int64_t i4 = 19; - int64_t i5 = 71; - int64_t i6 = 139; - int64_t i7 = 261; - int64_t i8 = 261; - int64_t i9 = 261; - int64_t i10 = 261; - auto x0 = at::rand({i2, i3}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto x1 = at::rand({i2, i4}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto x2 = at::rand({i2, i5}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto x3 = at::ones({i2, i6}, at::TensorOptions(at::kCPU).dtype(at::kLong)); - auto x4 = at::rand({i7}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto x5 = at::rand({i8}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto ref = at::mul(at::add(at::cat({x0, x1, x3, x2}, 1), x5), x4); - - { - std::vector inputs = {x0, x1, x2, x3, x4, x5}; - std::vector stack = at::fmap(inputs); - stack.emplace_back(i10); - stack.emplace_back(i9); - stack.emplace_back(i8); - stack.emplace_back(i7); - stack.emplace_back(i6); - stack.emplace_back(i5); - stack.emplace_back(i4); - stack.emplace_back(i3); - stack.emplace_back(i2); - k.run(stack); - - auto o = stack[0].toTensor(); - ASSERT_TRUE(at::allclose(o, ref)); - } - - { - auto out = - at::rand({i2, i10}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - std::vector inputs = {out, x0, x1, x2, x3, x4, x5}; - std::vector stack = at::fmap(inputs); - stack.emplace_back(i10); - stack.emplace_back(i9); - stack.emplace_back(i8); - stack.emplace_back(i7); - stack.emplace_back(i6); - stack.emplace_back(i5); - stack.emplace_back(i4); - stack.emplace_back(i3); - stack.emplace_back(i2); - k.runWithAllocatedOutputs(stack); - - ASSERT_TRUE(at::allclose(out, ref)); - } -#endif -} - -TEST(DynamicShapes, MultiThreadedExecution) { -#ifdef TORCH_ENABLE_LLVM - const auto graph_template = R"IR( - graph(%x : Float(SS(-2), SS(-3), requires_grad=0, device=${device}), - %y : Float(SS(-2), SS(-3), requires_grad=0, device=${device}), - %SS_2 : int, - %SS_3 : int): - %3 : Float(SS(-2), SS(-3), requires_grad=0, device=${device}) = aten::tanh(%x) - %4 : Float(SS(-2), SS(-3), requires_grad=0, device=${device}) = aten::erf(%3) - %5 : Float(SS(-2), SS(-3), requires_grad=0, device=${device}) = aten::mul(%4, %y) - return (%5))IR"; - for (bool use_cuda : {false, true}) { - if (!torch::cuda::is_available() && use_cuda) { - continue; - } - auto device = use_cuda ? at::kCUDA : at::kCPU; - at::jit::TemplateEnv env; - env.s("device", use_cuda ? 
"cuda:0" : "cpu"); - const auto graph_string = format(graph_template, env); - std::shared_ptr graph = std::make_shared(); - torch::jit::parseIR(graph_string, graph.get()); - - std::vector symbolic_shape_inputs = {-2, -3}; - - std::vector input_desc = { - torch::jit::StrideInput::TENSOR_CONT}; - std::unordered_map< - const torch::jit::Value*, - std::vector> - symbolic_strides; - symbolic_strides[graph->inputs().at(0)] = input_desc; - symbolic_strides[graph->inputs().at(1)] = input_desc; - symbolic_strides[graph->outputs().at(0)] = input_desc; - - TensorExprKernel kernel( - graph, {}, symbolic_shape_inputs, false, symbolic_strides); - - auto run_kernel = [&](int dim1, int dim2) { - auto a = - at::rand({dim1, dim2}, at::TensorOptions(device).dtype(at::kFloat)); - auto b = - at::rand({dim1, dim2}, at::TensorOptions(device).dtype(at::kFloat)); - - auto ref = at::mul(at::erf(at::tanh(a)), b); - - std::vector stack = fmap(std::vector({a, b})); - stack.emplace_back(dim1); - stack.emplace_back(dim2); - kernel.run(stack); - - auto o = stack[0].toTensor(); - ASSERT_TRUE(at::allclose(o, ref)); - }; - - // Run the kernel in parallel to ensure that the run() method calls in - // TensorExprKernel are not changing any state. - constexpr size_t kNumThreads = 4; - std::vector threads; - for (size_t id = 0; id < kNumThreads; ++id) { - threads.emplace_back(run_kernel, id + 5, id + 20); - } - for (auto& t : threads) { - t.join(); - } - } -#endif -} - -} // namespace jit -} // namespace torch diff --git a/test/cpp/tensorexpr/test_expr.cpp b/test/cpp/tensorexpr/test_expr.cpp deleted file mode 100644 index eb2d6296b229..000000000000 --- a/test/cpp/tensorexpr/test_expr.cpp +++ /dev/null @@ -1,836 +0,0 @@ -#include - -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include - -namespace torch { -namespace jit { -using namespace torch::jit::tensorexpr; - -using SimpleIRExprEval = ExprEval; - -TEST(Expr, BasicValueTest) { - ExprHandle a = IntImm::make(2), b = IntImm::make(3); - ExprHandle c = Add::make(a, b); - SimpleIRExprEval eval(c); - ASSERT_EQ(eval.value(), 5); -} - -TEST(Expr, BasicValueTest02) { - ExprHandle a(2.0f); - ExprHandle b(3.0f); - ExprHandle c(4.0f); - ExprHandle d(5.0f); - ExprHandle f = (a + b) - (c + d); - SimpleIRExprEval eval(f); - ASSERT_EQ(eval.value(), -4.0f); -} - -TEST(Expr, IsChannelsLastContiguous) { - std::vector vars = { - VarHandle("var1", kLong), - VarHandle("var2", kLong), - VarHandle("var3", kLong), - VarHandle("var4", kLong), - VarHandle("var5", kLong)}; - - // { - // key: ndims, - // value: [ - // ... - // [dim_2, dim_1, ..., dim_n] - // ] - // } - using shapGenInfo = std::unordered_map>>; - - // { - // size: [ExprHandle_1, ExprHandle_2, ..., ExprHandle_n], - // strides: [ - // ... 
- // [ExprHandle_x, ExprHandle_y, ..., ExprHandle_z] - // ] - // } - using shapeInfo = - std::pair, std::vector>>; - - std::vector dims = {3, 4, 5}; - - std::unordered_map> dims_expr_vec_conf = { - {3, std::vector(vars.begin(), vars.begin() + 2)}, - {4, std::vector(vars.begin(), vars.begin() + 3)}, - {5, std::vector(vars.begin(), vars.begin() + 4)}, - }; - - shapGenInfo channels_last_cont_shape_conf = { - {3, {{1, 2, 0}}}, {4, {{1, 3, 2, 0}}}, {5, {{1, 4, 3, 2, 0}}}}; - shapGenInfo channels_last_non_cont_shape_conf = { - {3, {{2, 1, 0}, {1, 0, 2}}}, - {4, {{3, 1, 2, 0}, {1, 2, 3, 0}, {1, 0, 2, 3}}}, - {5, {{4, 3, 2, 1, 0}, {1, 3, 2, 4, 0}, {1, 4, 3, 2, 0}}}}; - - shapGenInfo cont_shape_conf = { - {3, {{0, 1, 2}}}, {4, {{0, 1, 2, 3}}}, {5, {{0, 1, 2, 3, 4}}}}; - - auto shape_gen_fn = [dims_expr_vec_conf]( - int ndims, shapGenInfo shape_gen_info) -> shapeInfo { - auto dims_expr_vec = dims_expr_vec_conf.at(ndims); - std::vector> strides_expr_vec; - for (size_t i = 0; i < strides_expr_vec.size(); i++) { - strides_expr_vec[i].resize(ndims); - } - - auto stride_gen_fn = [](int indicator, ExprHandle a, ExprHandle b) { - if (indicator % 2 == 0) { - return a * b; - } else { - return b * a; - } - }; - - auto stride_order_vec = shape_gen_info.at(ndims); - for (size_t i = 0; i < strides_expr_vec.size(); i++) { - auto stride_order = stride_order_vec[i]; - - strides_expr_vec[i][stride_order[0]] = 1; - for (size_t j = 1; j < stride_order.size(); j++) { - auto cur_dim_idx = stride_order[j]; - auto adjacent_dim_idx = stride_order[j - 1]; - - strides_expr_vec[i][cur_dim_idx] = stride_gen_fn( - i, - dims_expr_vec[adjacent_dim_idx], - strides_expr_vec[i][adjacent_dim_idx]); - } - } - - return {dims_expr_vec, strides_expr_vec}; - }; - - auto check_channels_last_fn = [](int ndims, BufHandle buf_handle) -> bool { - if (ndims == 3) { - return buf_handle.is_channels_last_1d_contiguous(); - } else if (ndims == 4) { - return buf_handle.is_contiguous(at::MemoryFormat::ChannelsLast); - } else { - return buf_handle.is_contiguous(at::MemoryFormat::ChannelsLast3d); - } - }; - - // channels-last contiguous - for (size_t i = 0; i < dims.size(); i++) { - auto shape_info = shape_gen_fn(dims[i], channels_last_cont_shape_conf); - for (size_t j = 0; j < shape_info.second.size(); j++) { - BufHandle buf_handle("a", shape_info.first, shape_info.second[j], kFloat); - ASSERT_EQ(check_channels_last_fn(dims[i], buf_handle), true); - } - } - - // channels-last non-contiguous - for (size_t i = 0; i < dims.size(); i++) { - auto shape_info = shape_gen_fn(dims[i], channels_last_non_cont_shape_conf); - for (size_t j = 0; j < shape_info.second.size(); j++) { - BufHandle buf_handle("a", shape_info.first, shape_info.second[j], kFloat); - ASSERT_EQ(check_channels_last_fn(dims[i], buf_handle), false); - } - } - - // contiguous - for (size_t i = 0; i < dims.size(); i++) { - auto shape_info = shape_gen_fn(dims[i], cont_shape_conf); - for (size_t j = 0; j < shape_info.second.size(); j++) { - BufHandle buf_handle("a", shape_info.first, shape_info.second[j], kFloat); - ASSERT_EQ(buf_handle.is_contiguous(), true); - } - } - - // non-contiguous - for (size_t i = 0; i < dims.size(); i++) { - auto shape_info = shape_gen_fn(dims[i], channels_last_cont_shape_conf); - for (size_t j = 0; j < shape_info.second.size(); j++) { - BufHandle buf_handle("a", shape_info.first, shape_info.second[j], kFloat); - ASSERT_EQ(buf_handle.is_contiguous(), false); - } - } -} - -TEST(Expr, LetTest01) { - VarHandle x("x", kFloat); - ExprHandle body = ExprHandle(2.f) + (x * 
ExprHandle(3.f) + ExprHandle(4.f)); - SimpleIRExprEval eval(body); - eval.bindVar(x, ExprHandle(3.f)); - ASSERT_EQ(eval.value(), 2 + (3 * 3 + 4)); -} - -TEST(Expr, LetTest02) { - VarHandle x("x", kFloat); - VarHandle y("y", kFloat); - ExprHandle body = - ExprHandle(2.f) + (x * ExprHandle(3.f) + ExprHandle(4.f) * y); - SimpleIRExprEval eval(body); - eval.bindVar(x, ExprHandle(3.f)); - eval.bindVar(y, ExprHandle(6.f)); - ASSERT_EQ(eval.value(), 2 + (3 * 3 + 4 * 6)); -} - -TEST(Expr, LetStmtTest01) { - BufHandle a_buf("a", {1}, kFloat); - BufHandle b_buf("b", {1}, kFloat); - - ExprHandle load_a = a_buf.load(0); - VarHandle var = VarHandle("v", kFloat); - StmtPtr let_store = Let::make(var, load_a); - StmtPtr store_b = b_buf.store({0}, var); - BlockPtr block = Block::make({let_store, store_b}); - - SimpleIREvaluator eval(block, {a_buf, b_buf}); - - PaddedBuffer a_v(1); - PaddedBuffer b_v(1); - PaddedBuffer b_ref(1); - - a_v(0) = 23; - b_ref(0) = a_v(0); - eval(a_v, b_v); - - ExpectAllNear(b_v, b_ref, 1e-5); -} - -TEST(Expr, IntTest) { - VarHandle x("x", kInt); - ExprHandle body = ExprHandle(2) + (x * ExprHandle(3) + ExprHandle(4)); - SimpleIRExprEval eval(body); - eval.bindVar(x, ExprHandle(3)); - ASSERT_EQ(eval.value(), 2 + (3 * 3 + 4)); -} - -TEST(Expr, FloatTest) { - VarHandle x("x", kFloat); - ExprHandle body = ExprHandle(2.f) + (x * ExprHandle(3.f) + ExprHandle(4.f)); - SimpleIRExprEval eval(body); - eval.bindVar(x, ExprHandle(3.f)); - ASSERT_EQ(eval.value(), 2 + (3 * 3 + 4)); -} - -TEST(Expr, ByteTest) { - VarHandle x("x", kByte); - ExprHandle body = ExprHandle((uint8_t)2) + - (x * ExprHandle((uint8_t)3) + ExprHandle((uint8_t)4)); - SimpleIRExprEval eval(body); - eval.bindVar(x, ExprHandle((uint8_t)3)); - ASSERT_EQ(eval.value(), 2 + (3 * 3 + 4)); -} - -TEST(Expr, CharTest) { - VarHandle x("x", kChar); - ExprHandle body = ExprHandle((int8_t)2) + - (x * ExprHandle((int8_t)3) + ExprHandle((int8_t)4)); - SimpleIRExprEval eval(body); - eval.bindVar(x, ExprHandle((int8_t)3)); - ASSERT_EQ(eval.value(), 2 + (3 * 3 + 4)); -} - -TEST(Expr, ShortTest) { - VarHandle x("x", kShort); - ExprHandle body = ExprHandle((int16_t)2) + - (x * ExprHandle((int16_t)3) + ExprHandle((int16_t)4)); - SimpleIRExprEval eval(body); - eval.bindVar(x, ExprHandle((int16_t)3)); - ASSERT_EQ(eval.value(), 2 + (3 * 3 + 4)); -} - -TEST(Expr, LongTest) { - VarHandle x("x", kLong); - ExprHandle body = ExprHandle((int64_t)2) + - (x * ExprHandle((int64_t)3) + ExprHandle((int64_t)4)); - SimpleIRExprEval eval(body); - eval.bindVar(x, ExprHandle((int64_t)3)); - ASSERT_EQ(eval.value(), 2 + (3 * 3 + 4)); -} - -TEST(Expr, HalfTest) { - VarHandle x("x", kHalf); - ExprHandle body = ExprHandle((at::Half)2) + - (x * ExprHandle((at::Half)3) + ExprHandle((at::Half)4)); - SimpleIRExprEval eval(body); - eval.bindVar(x, ExprHandle((at::Half)3)); - ASSERT_EQ(eval.value(), 2 + (3 * 3 + 4)); -} - -TEST(Expr, DoubleTest) { - VarHandle x("x", kDouble); - ExprHandle body = ExprHandle((double)2) + - (x * ExprHandle((double)3) + ExprHandle((double)4)); - SimpleIRExprEval eval(body); - eval.bindVar(x, ExprHandle((double)3)); - ASSERT_EQ(eval.value(), 2 + (3 * 3 + 4)); -} - -TEST(Expr, VectorAdd01) { - const int kVectorSize = 8; - const int kVectorCount = 128; - const int kTotalSize = kVectorSize * kVectorCount; - - BufHandle a_buf("A", {kTotalSize}, kFloat); - BufHandle b_buf("B", {kTotalSize}, kFloat); - BufHandle c_buf("C", {kTotalSize}, kFloat); - - /* - Build the following: - for (const auto index : c10::irange(kVectorCount)) { - store(c_buf, 
ramp(index * 8, 1, 8), - load(a_buf, ramp(index * 8, 1, 8) + - load(b_buf, ramp(index * 8, 1, 8)))) - } - */ - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = - a_buf.load({Ramp::make(index * kVectorSize, 1, kVectorSize)}); - ExprHandle load_b = - b_buf.load({Ramp::make(index * kVectorSize, 1, kVectorSize)}); - ExprHandle value = load_a + load_b; - StmtPtr store_c = - c_buf.store({Ramp::make(index * kVectorSize, 1, kVectorSize)}, value); - StmtPtr stmt = For::make(index, 0, kVectorCount, store_c); - - ASSERT_EQ(load_a.dtype(), Dtype(kFloat, kVectorSize)); - ASSERT_EQ(load_b.dtype(), Dtype(kFloat, kVectorSize)); - ASSERT_EQ(value.dtype(), Dtype(kFloat, kVectorSize)); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - PaddedBuffer c_v(kTotalSize); - PaddedBuffer c_ref(kTotalSize); - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = i * i; - b_v(i) = i * i * 4; - c_ref(i) = a_v(i) + b_v(i); - } - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf}); - ir_eval(a_v, b_v, c_v); - ExpectAllNear(c_v, c_ref, 1e-5); -} - -TEST(Expr, CompareSelectEQ) { - constexpr int N = 1024; - BufHandle a("A", {N}, kInt); - BufHandle b("B", {N}, kInt); - BufHandle c("C", {N}, kInt); - std::vector a_buffer(N, 1); - std::vector b_buffer(N, 1); - std::vector c_buffer(N, 0); - std::vector c_ref(N, 0); - - VarHandle i("i", kInt); - auto memcpy_expr = For::make( - i, - 0, - N, - c.store( - {i}, - CompareSelect::make( - a.load(i), b.load(i), CompareSelectOperation::kEQ))); - - SimpleIREvaluator ir_eval(memcpy_expr, {a, b, c}); - ir_eval(a_buffer, b_buffer, c_buffer); - - ASSERT_EQ(a_buffer.size(), N); - ASSERT_EQ(b_buffer.size(), N); - ASSERT_EQ(c_buffer.size(), N); - - assertAllEqual(a_buffer, 1); - assertAllEqual(b_buffer, 1); - assertAllEqual(c_buffer, 1); -} - -TEST(Expr, CompareSelectDtypes) { - // LHS and RHS expressions should have the same dtype, but this dtype could - // differ from the dtype of the return values (but dtypes of true and false - // return values should be the same). - // This test constructs a CompareSelect expression where the input dtype is - // different from the output dtype and verifies that it works correctly: - // result = ((int)lhs == (int)rhs) ? (float)retval1 : (float)retval2 - constexpr int N = 1024; - BufHandle a("A", {N}, kInt); - BufHandle b("B", {N}, kInt); - BufHandle c("C", {N}, kFloat); - std::vector a_buffer(N, 1); - std::vector b_buffer(N, 1); - std::vector c_buffer(N, 0.0f); - std::vector c_ref(N, 3.14f); - - VarHandle i("i", kInt); - // C[i] = (A[i] == B[i]) ? 3.14f : 2.78f - // A and B are int, C is float. 
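In scalar terms the node under test evaluates (lhs OP rhs) ? ret_true : ret_false, where lhs and rhs share one dtype and the two return values share another. A worked single-element example, assuming the removed expression-evaluator API:

    // result = (1 == 1) ? 3.14f : 2.78f; int comparison, float result.
    ExprHandle cmp = CompareSelect::make(
        IntImm::make(1),
        IntImm::make(1),
        FloatImm::make(3.14f),
        FloatImm::make(2.78f),
        CompareSelectOperation::kEQ);
    ExprEval<SimpleIREvaluator> eval(cmp);
    ASSERT_EQ(eval.value<float>(), 3.14f);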
- auto select_expr = For::make( - i, - 0, - N, - c.store( - {i}, - CompareSelect::make( - a.load(i), - b.load(i), - FloatImm::make(3.14f), - FloatImm::make(2.78f), - CompareSelectOperation::kEQ))); - - SimpleIREvaluator ir_eval(select_expr, {a, b, c}); - ir_eval(a_buffer, b_buffer, c_buffer); - - ASSERT_EQ(a_buffer.size(), N); - ASSERT_EQ(b_buffer.size(), N); - ASSERT_EQ(c_buffer.size(), N); - - assertAllEqual(a_buffer, 1); - assertAllEqual(b_buffer, 1); - ExpectAllNear(c_buffer, c_ref, 1e-7); -} - -TEST(Expr, IntrinsicsDtypes) { - constexpr int N = 256; - BufHandle a("A", {N}, kDouble); - BufHandle b("B", {N}, kDouble); - std::vector a_buffer(N, -10.0); - std::vector b_buffer(N, 0.0); - std::vector b_ref(N, 10.0); - - VarHandle i("i", kInt); - auto abs_expr = For::make(i, 0, N, b.store({i}, tensorexpr::abs(a.load(i)))); - - SimpleIREvaluator ir_eval(abs_expr, {a, b}); - ir_eval(a_buffer, b_buffer); - - ASSERT_EQ(a_buffer.size(), N); - ASSERT_EQ(b_buffer.size(), N); - - assertAllEqual(a_buffer, -10.0); - ExpectAllNear(b_buffer, b_ref, 1e-7); -} - -TEST(Expr, Substitute01) { - VarPtr x = alloc("x", kFloat); - VarPtr y = alloc("y", kFloat); - ExprPtr e = - alloc(alloc(x, alloc(1.0f)), alloc(x, y)); - - VarPtr z = alloc("z", kFloat); - ExprPtr e2 = Substitute(e, {{x, alloc(z, alloc(5.0f))}}); - ExprPtr e2_ref = alloc( - alloc(alloc(z, alloc(5.0f)), alloc(1.0f)), - alloc(alloc(z, alloc(5.0f)), y)); - std::ostringstream oss; - oss << *e2; - std::string e2_str = oss.str(); - - oss.str(""); - oss << *e2_ref; - std::string e2_ref_str = oss.str(); - ASSERT_EQ(e2_str, e2_ref_str); -} - -TEST(Expr, Math01) { - ExprHandle v = sin(ExprHandle(1.0f)); - - std::ostringstream oss; - oss << v; - ASSERT_EQ(oss.str(), "sin(1.f)"); - - SimpleIRExprEval eval(v); - float v_ref = std::sin(1.0f); - float res = eval.value(); - ASSERT_NEAR(res, v_ref, 1e-6); -} - -TEST(Expr, UnaryMath01) { - struct TestConfig { - std::function func; - std::function ref_func; - }; - - std::vector test_configs = { - {[](const ExprHandle& v) { return sin(v); }, - [](float v) { return std::sin(v); }}, - {[](const ExprHandle& v) { return sin(v); }, - [](float v) { return std::sin(v); }}, - {[](const ExprHandle& v) { return tan(v); }, - [](float v) { return std::tan(v); }}, - {[](const ExprHandle& v) { return asin(v); }, - [](float v) { return std::asin(v); }}, - {[](const ExprHandle& v) { return acos(v); }, - [](float v) { return std::acos(v); }}, - {[](const ExprHandle& v) { return atan(v); }, - [](float v) { return std::atan(v); }}, - {[](const ExprHandle& v) { return sinh(v); }, - [](float v) { return std::sinh(v); }}, - {[](const ExprHandle& v) { return cosh(v); }, - [](float v) { return std::cosh(v); }}, - {[](const ExprHandle& v) { return tanh(v); }, - [](float v) { return std::tanh(v); }}, - {[](const ExprHandle& v) { return exp(v); }, - [](float v) { return std::exp(v); }}, - {[](const ExprHandle& v) { return tensorexpr::abs(v); }, - [](float v) { return std::fabs(v); }}, - {[](const ExprHandle& v) { return log(v); }, - [](float v) { return std::log(v); }}, - {[](const ExprHandle& v) { return log2(v); }, - [](float v) { return std::log2(v); }}, - {[](const ExprHandle& v) { return log10(v); }, - [](float v) { return std::log10(v); }}, - {[](const ExprHandle& v) { return erf(v); }, - [](float v) { return std::erf(v); }}, - {[](const ExprHandle& v) { return sqrt(v); }, - [](float v) { return std::sqrt(v); }}, - {[](const ExprHandle& v) { return rsqrt(v); }, - [](float v) { return 1.0f / std::sqrt(v); }}, - {[](const ExprHandle& v) { 
return ceil(v); }, - [](float v) { return std::ceil(v); }}, - {[](const ExprHandle& v) { return floor(v); }, - [](float v) { return std::floor(v); }}, - {[](const ExprHandle& v) { return round(v); }, - [](float v) { return std::round(v); }}, - {[](const ExprHandle& v) { return trunc(v); }, - [](float v) { return std::trunc(v); }}, - }; - - for (const TestConfig& test_config : test_configs) { - const float input_v = 0.8765f; - ExprHandle v = test_config.func(ExprHandle(input_v)); - float v_ref = test_config.ref_func(input_v); - SimpleIRExprEval eval(v); - ASSERT_NEAR(eval.value(), v_ref, 1e-6); - } - - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) - for (float input_v : {std::nan("1"), 0., .5}) { - ExprHandle v = FloatImm::make(input_v); - SimpleIRExprEval eval(Intrinsics::make(kIsNan, v)); - ASSERT_NEAR(eval.value(), std::isnan(input_v), 0); - } -} - -TEST(Expr, BinaryMath01) { - struct TestConfig { - std::function func; - std::function ref_func; - }; - - std::vector test_configs = { - {[](const ExprHandle& v1, const ExprHandle& v2) { return pow(v1, v2); }, - [](float v1, float v2) { return std::pow(v1, v2); }}, - {[](const ExprHandle& v1, const ExprHandle& v2) { return fmod(v1, v2); }, - [](float v1, float v2) { return std::fmod(v1, v2); }}, - }; - - for (const TestConfig& test_config : test_configs) { - const float v1 = 0.8765f; - float v2 = 1.2345f; - ExprHandle v_expr = test_config.func(ExprHandle(v1), ExprHandle(v2)); - float v_ref = test_config.ref_func(v1, v2); - SimpleIRExprEval eval(v_expr); - ASSERT_NEAR(eval.value(), v_ref, 1e-6); - } -} - -TEST(Expr, LogicalOps01) { - ExprHandle a(23); - ExprHandle b(11); - ExprHandle c(0.72f); - ExprHandle d(0.69f); - ExprHandle f1 = (a > b) && (c > d); - ExprHandle f2 = (a > b) && (c < d); - ExprHandle f3 = (a < b) && (c > d); - ExprHandle f4 = (a < b) && (c < d); - ExprHandle f5 = (a < b) || (c > d); - ExprHandle f6 = (a < b) || (c < d); - ExprHandle f7 = (a > b) || (c < d); - ExprHandle f8 = (a > b) || (c > d); - - SimpleIRExprEval eval1(f1); - SimpleIRExprEval eval2(f2); - SimpleIRExprEval eval3(f3); - SimpleIRExprEval eval4(f4); - SimpleIRExprEval eval5(f5); - SimpleIRExprEval eval6(f6); - SimpleIRExprEval eval7(f7); - SimpleIRExprEval eval8(f8); - ASSERT_EQ(eval1.value(), 1); - ASSERT_EQ(eval2.value(), 0); - ASSERT_EQ(eval3.value(), 0); - ASSERT_EQ(eval4.value(), 0); - ASSERT_EQ(eval5.value(), 1); - ASSERT_EQ(eval6.value(), 0); - ASSERT_EQ(eval7.value(), 1); - ASSERT_EQ(eval8.value(), 1); -} - -TEST(Expr, LogicalOps02) { - ExprHandle a(23); - ExprHandle b(11); - ExprHandle c(0.72f); - ExprHandle d(0.72f); - - ExprHandle f1 = (a > b) || (c > d); - ExprHandle f2 = (a > b) && (c <= d); - ExprHandle f3 = (a > b) && (c > d); - ExprHandle ff1 = f1 && f2; - ExprHandle ff2 = f2 || f3; - - SimpleIRExprEval eval1(ff1); - SimpleIRExprEval eval2(ff2); - ASSERT_EQ(eval1.value(), 1); - ASSERT_EQ(eval2.value(), 1); -} - -TEST(Expr, LogicalOps03) { - ExprHandle a(23); - ExprHandle b(11); - ExprHandle c(0.72f); - ExprHandle d(0.69f); - - // Bool types - ExprHandle bool_f1 = (a > b) && BoolImm::make(true); - ExprHandle bool_f2 = (c <= d) || BoolImm::make(true); - - // Int types - ExprHandle int_f1 = (a > b) && IntImm::make(1); - ExprHandle int_f2 = (c <= d) || IntImm::make(1); - - // Short types - ExprHandle short_f1 = (a > b) && ShortImm::make(1); - ExprHandle short_f2 = (c <= d) || ShortImm::make(1); - - // Long types - ExprHandle long_f1 = (a > b) && LongImm::make(1); - ExprHandle long_f2 = (c <= d) || 
LongImm::make(1); - - // Char types - ExprHandle char_f1 = (a > b) && CharImm::make(1); - ExprHandle char_f2 = (c <= d) || CharImm::make(1); - - // Byte types - ExprHandle byte_f1 = (a > b) && ByteImm::make(1); - ExprHandle byte_f2 = (c <= d) || ByteImm::make(1); - - SimpleIRExprEval eval1(bool_f1); - SimpleIRExprEval eval2(bool_f2); - SimpleIRExprEval eval3(int_f1); - SimpleIRExprEval eval4(int_f2); - SimpleIRExprEval eval5(short_f1); - SimpleIRExprEval eval6(short_f2); - SimpleIRExprEval eval7(long_f1); - SimpleIRExprEval eval8(long_f2); - SimpleIRExprEval eval9(char_f1); - SimpleIRExprEval eval10(char_f2); - SimpleIRExprEval eval11(byte_f1); - SimpleIRExprEval eval12(byte_f2); - - ASSERT_EQ(eval1.value(), true); - ASSERT_EQ(eval2.value(), true); - ASSERT_EQ(eval3.value(), 1); - ASSERT_EQ(eval4.value(), 1); - ASSERT_EQ(eval5.value(), 1); - ASSERT_EQ(eval6.value(), 1); - ASSERT_EQ(eval7.value(), 1); - ASSERT_EQ(eval8.value(), 1); - ASSERT_EQ(eval9.value(), 1); - ASSERT_EQ(eval10.value(), 1); - ASSERT_EQ(eval11.value(), 1); - ASSERT_EQ(eval12.value(), 1); -} - -TEST(Expr, BitwiseOps) { - ExprHandle a(59); - ExprHandle b(11); - ExprHandle c(101); - ExprHandle d(2); - ExprHandle f = (((a ^ (b << 1)) & c) >> 2) | d; - - SimpleIRExprEval eval(f); - ASSERT_EQ(eval.value(), 11); -} - -TEST(Expr, DynamicShapeAdd) { - auto testWithSize = [](int32_t size) { - VarHandle n("n", kInt); - BufHandle a("a", {n}, kFloat); - BufHandle b("b", {n}, kFloat); - BufHandle c("c", {n}, kFloat); - VarHandle i("i", kInt); - StmtPtr s = For::make(i, 0, n, c.store({i}, a.load(i) + b.load(i))); - std::vector aData(size, 1.0f); - std::vector bData(size, 2.0f); - std::vector cData(size, 0.0f); - SimpleIREvaluator(s, {a, b, c, n})(aData, bData, cData, size); - ExpectAllNear(cData, std::vector(size, 3.0f), 1e-7); - }; - testWithSize(1); - testWithSize(16); - testWithSize(37); -} - -TEST(Expr, OutOfBounds) { - ExprHandle N(10); - ExprHandle start(0); - ExprHandle stop(15); - VarHandle i("i", kInt); - - BufHandle X("X", {N}, kInt); - - auto body = Store::make(X, {i}, i); - auto stmt = For::make(i, start, stop, body); - - PaddedBuffer data(20); - - EXPECT_ANY_THROW(SimpleIREvaluator(stmt, {X})(data)); -} - -TEST(Expr, OutOfBounds2d) { - std::vector> size_options = {{10, 15}, {15, 10}}; - for (auto sizes : size_options) { - ExprHandle N(sizes.first); - ExprHandle M(sizes.second); - ExprHandle start(0); - ExprHandle stopInner(15); - ExprHandle stopOuter(15); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - - BufHandle X("X", {N, M}, kInt); - - auto body = Store::make(X, {i, j}, i); - auto inner = For::make(j, start, stopInner, body); - auto stmt = For::make(i, start, stopOuter, inner); - - PaddedBuffer data(400); - - EXPECT_ANY_THROW(SimpleIREvaluator(stmt, {X})(data)); - } -} - -TEST(Expr, OutOfBounds2dFlattenedIndex) { - ExprHandle buf_size(149); - ExprHandle start(0); - ExprHandle stopInner(15); - ExprHandle stopOuter(10); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - - BufHandle X("X", {buf_size}, kInt); - - auto idx = Add::make(Mul::make(i, stopInner), j); - auto body = Store::make(X, {idx}, i); - auto inner = For::make(j, start, stopInner, body); - auto stmt = For::make(i, start, stopOuter, inner); - - PaddedBuffer data(400); - - EXPECT_ANY_THROW(SimpleIREvaluator(stmt, {X})(data)); -} - -void testCond01() { - const int N = 16; - PaddedBuffer a_v(N); - BufHandle a_buf("a", {N}, kFloat); - VarHandle index = VarHandle("index", kInt); - StmtPtr assign_x2 = a_buf.store({index}, cast(index) * 2); - StmtPtr 
assign_x3 = a_buf.store({index}, cast(index) * 3); - ExprHandle even_cond = CompareSelect::make(Mod::make(index, 2), 0, kEQ); - StmtPtr assign = Cond::make(even_cond, assign_x2, assign_x3); - StmtPtr for_stmt = For::make(index, 0, N, assign); - SimpleIREvaluator(for_stmt, {a_buf})(a_v); - - PaddedBuffer a_ref(N); - for (const auto i : c10::irange(N)) { - if (i % 2 == 0) { - a_ref(i) = i * 2; - } else { - a_ref(i) = i * 3; - } - } - ExpectAllNear(a_v, a_ref, 1e-5); -} - -void testIfThenElse01() { - ExprHandle v = ifThenElse(ExprHandle(1), ExprHandle(1.0f), ExprHandle(2.0f)); - - std::ostringstream oss; - oss << v; - ASSERT_EQ(oss.str(), "IfThenElse(1, 1.f, 2.f)"); - - SimpleIRExprEval eval(v); - ASSERT_EQ(eval.value(), 1.0f); -} - -void testIfThenElse02() { - ExprHandle v = ifThenElse(ExprHandle(0), ExprHandle(1.0f), ExprHandle(2.0f)); - - std::ostringstream oss; - oss << v; - ASSERT_EQ(oss.str(), "IfThenElse(0, 1.f, 2.f)"); - - SimpleIRExprEval eval(v); - ASSERT_EQ(eval.value(), 2.0f); -} - -void testIfThenElse03() { - ExprHandle v = - ifThenElse(BoolImm::make(false), ExprHandle(1.0f), ExprHandle(2.0f)); - - std::ostringstream oss; - oss << v; - ASSERT_EQ(oss.str(), "IfThenElse(0, 1.f, 2.f)"); - - SimpleIRExprEval eval(v); - ASSERT_EQ(eval.value(), 2.0f); -} - -void testStmtClone() { - const int N = 16; - - BufHandle a_buf("a", {N}, kInt); - VarHandle index = VarHandle("index", kInt); - StmtPtr body = a_buf.store({index}, 5); - StmtPtr loop = For::make(index, 0, N, body); - - StmtPtr cloned_loop = Stmt::clone(loop); - std::vector orig_loop_results(N); - std::vector cloned_loop_results(N); - SimpleIREvaluator(loop, {a_buf})(orig_loop_results); - SimpleIREvaluator(cloned_loop, {a_buf})(cloned_loop_results); - - assertAllEqual(orig_loop_results, 5); - assertAllEqual(cloned_loop_results, 5); - - // Let's add another assign to the body in the cloned loop and verify that the - // original statement hasn't changed while the cloned one has. 
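The check that follows relies on Stmt::clone being a deep copy. With the template arguments written out explicitly, the mutation step reads roughly:

    StmtPtr cloned_loop = Stmt::clone(loop);

    // static_to<T> downcasts a statement pointer; grab the clone's body and
    // append a second store to it. Only the clone is touched.
    BlockPtr cloned_body =
        static_to<Block>(static_to<For>(cloned_loop)->body());
    cloned_body->append_stmt(a_buf.store({index}, 33));

    // Re-evaluating loop still yields 5 everywhere; the clone yields 33.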
- StmtPtr body_addition = a_buf.store({index}, 33); - BlockPtr cloned_body = static_to(static_to(cloned_loop)->body()); - cloned_body->append_stmt(body_addition); - - std::vector orig_loop_results_after_mutation(N); - std::vector cloned_loop_results_after_mutation(N); - SimpleIREvaluator(loop, {a_buf})(orig_loop_results_after_mutation); - SimpleIREvaluator(cloned_loop, {a_buf})(cloned_loop_results_after_mutation); - - assertAllEqual(orig_loop_results_after_mutation, 5); - assertAllEqual(cloned_loop_results_after_mutation, 33); -} - -} // namespace jit -} // namespace torch diff --git a/test/cpp/tensorexpr/test_external_calls.cpp b/test/cpp/tensorexpr/test_external_calls.cpp deleted file mode 100644 index 49f43d16b499..000000000000 --- a/test/cpp/tensorexpr/test_external_calls.cpp +++ /dev/null @@ -1,1061 +0,0 @@ -#include - -#include - -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include - -#include -#include -#include - -namespace torch { -namespace jit { -using namespace torch::jit::tensorexpr; - -TEST(ExternalCall, Conv1d_float) { - BufHandle Input("Input", {1, 100, 115}, kFloat); - BufHandle Weight("Weight", {100, 1, 7}, kFloat); - BufHandle Bias("Bias", {100}, kFloat); - BufHandle ResultBuf("Result", {1, 100, 115}, kFloat); - int64_t stride = 1; - int64_t pad = 3; - int64_t dilation = 1; - int64_t groups = 100; - - Tensor Result = Tensor( - ResultBuf.node(), - ExternalCall::make( - ResultBuf, - "nnc_aten_conv1d", - {Input, Weight, Bias}, - {stride, pad, dilation, groups})); - LoopNest l({Result}); - l.prepareForCodegen(); - l.simplify(); - - auto options = at::TensorOptions() - .dtype(at::kFloat) - .layout(at::kStrided) - .device(at::kCPU) - .requires_grad(false); - at::Tensor input = at::ones({1, 100, 115}, options) * 5.f; - at::Tensor weight = at::ones({100, 1, 7}, options) * 6.f; - at::Tensor bias = at::ones({100}, options) * 11.f; - at::Tensor ref = - at::conv1d(input, weight, bias, {stride}, {pad}, {dilation}, groups); - - at::Tensor nnc_result; - std::vector input_buf(1 * 100 * 115, 5.f); - std::vector weight_buf(100 * 1 * 7, 6.f); - std::vector bias_buf(100, 11.f); - std::vector result_buf(1 * 100 * 115, -1.f); - -#ifdef TORCH_ENABLE_LLVM - LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Weight, Bias, Result}); - - llvm_codegen.call({input_buf, weight_buf, bias_buf, result_buf}); - nnc_result = at::from_blob(result_buf.data(), {1, 100, 115}, options); - ASSERT_TRUE(at::allclose(nnc_result, ref)); -#endif - - SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Weight, Bias, Result}); - - ir_eval.call({input_buf, weight_buf, bias_buf, result_buf}); - nnc_result = at::from_blob(result_buf.data(), {1, 100, 115}, options); - ASSERT_TRUE(at::allclose(nnc_result, ref)); -} - -TEST(ExternalCall, Conv1d_int) { - // A similar test, but now using kInt tensors - BufHandle Input("Input", {1, 100, 115}, kInt); - BufHandle Weight("Weight", {100, 1, 7}, kInt); - BufHandle Bias("Bias", {100}, kInt); - BufHandle ResultBuf("Result", {1, 100, 115}, kInt); - int64_t stride = 1; - int64_t pad = 3; - int64_t dilation = 1; - int64_t groups = 100; - - Tensor Result = Tensor( - ResultBuf.node(), - ExternalCall::make( - ResultBuf, - "nnc_aten_conv1d", - {Input, Weight, Bias}, - {stride, pad, dilation, groups})); - LoopNest l({Result}); - l.prepareForCodegen(); - l.simplify(); - - auto options = at::TensorOptions() - .dtype(at::kInt) - .layout(at::kStrided) - .device(at::kCPU) - 
.requires_grad(false); - at::Tensor input = at::ones({1, 100, 115}, options) * 5; - at::Tensor weight = at::ones({100, 1, 7}, options) * 6; - at::Tensor bias = at::ones({100}, options) * 11; - at::Tensor ref = - at::conv1d(input, weight, bias, {stride}, {pad}, {dilation}, groups); - - at::Tensor nnc_result; - std::vector input_buf(1 * 100 * 115, 5); - std::vector weight_buf(100 * 1 * 7, 6); - std::vector bias_buf(100, 11); - std::vector result_buf(1 * 100 * 115, -1); - -#ifdef TORCH_ENABLE_LLVM - LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Weight, Bias, Result}); - - llvm_codegen.call({input_buf, weight_buf, bias_buf, result_buf}); - nnc_result = at::from_blob(result_buf.data(), {1, 100, 115}, options); - ASSERT_TRUE(at::allclose(nnc_result, ref)); -#endif - - SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Weight, Bias, Result}); - - ir_eval.call({input_buf, weight_buf, bias_buf, result_buf}); - nnc_result = at::from_blob(result_buf.data(), {1, 100, 115}, options); - ASSERT_TRUE(at::allclose(nnc_result, ref)); -} - -TEST(ExternalCall, Conv1d_nobias_noargs) { - BufHandle Input("Input", {1, 1, 115}, kFloat); - BufHandle Weight("Weight", {10, 1, 7}, kFloat); - BufHandle ResultBuf("Result", {1, 10, 109}, kFloat); - - Tensor Result = Tensor( - ResultBuf.node(), - ExternalCall::make(ResultBuf, "nnc_aten_conv1d", {Input, Weight}, {})); - LoopNest l({Result}); - l.prepareForCodegen(); - l.simplify(); - - auto options = at::TensorOptions() - .dtype(at::kFloat) - .layout(at::kStrided) - .device(at::kCPU) - .requires_grad(false); - at::Tensor input = at::ones({1, 1, 115}, options) * 5.f; - at::Tensor weight = at::ones({10, 1, 7}, options) * 6.f; - at::Tensor ref = at::conv1d(input, weight); - - at::Tensor nnc_result; - std::vector input_buf(1 * 1 * 115, 5.f); - std::vector weight_buf(10 * 1 * 7, 6.f); - std::vector result_buf(1 * 10 * 109, -1.f); - -#ifdef TORCH_ENABLE_LLVM - LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Weight, Result}); - - llvm_codegen.call({input_buf, weight_buf, result_buf}); - nnc_result = at::from_blob(result_buf.data(), {1, 10, 109}, options); - ASSERT_TRUE(at::allclose(nnc_result, ref)); -#endif - - SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Weight, Result}); - - ir_eval.call({input_buf, weight_buf, result_buf}); - nnc_result = at::from_blob(result_buf.data(), {1, 10, 109}, options); - ASSERT_TRUE(at::allclose(nnc_result, ref)); -} - -TEST(ExternalCall, Conv2d_float) { - BufHandle Input("Input", {1, 3, 224, 224}, kFloat); - BufHandle Weight("Weight", {16, 3, 3, 3}, kFloat); - BufHandle Bias("Bias", {16}, kFloat); - BufHandle ResultBuf("Result", {1, 16, 112, 112}, kFloat); - int64_t stride = 2; - int64_t pad = 1; - int64_t dilation = 1; - int64_t groups = 1; - - Tensor Result = Tensor( - ResultBuf.node(), - ExternalCall::make( - ResultBuf, - "nnc_aten_conv2d", - {Input, Weight, Bias}, - {stride, stride, pad, pad, dilation, dilation, groups})); - LoopNest l({Result}); - l.prepareForCodegen(); - l.simplify(); - - auto options = at::TensorOptions() - .dtype(at::kFloat) - .layout(at::kStrided) - .device(at::kCPU) - .requires_grad(false); - at::Tensor input = at::ones({1, 3, 224, 224}, options) * 5.f; - at::Tensor weight = at::ones({16, 3, 3, 3}, options) * 6.f; - at::Tensor bias = at::ones({16}, options) * 11.f; - at::Tensor ref = at::conv2d( - input, - weight, - bias, - {stride, stride}, - {pad, pad}, - {dilation, dilation}, - groups); - - at::Tensor nnc_result; - std::vector input_buf(1 * 3 * 224 * 224, 5.f); - std::vector weight_buf(16 * 3 * 3 * 3, 6.f); - 
std::vector bias_buf(16, 11.f); - std::vector result_buf(1 * 16 * 112 * 112, -1.f); - -#ifdef TORCH_ENABLE_LLVM - LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Weight, Bias, Result}); - - llvm_codegen.call({input_buf, weight_buf, bias_buf, result_buf}); - nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options); - ASSERT_TRUE(at::allclose(nnc_result, ref)); -#endif - - SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Weight, Bias, Result}); - - ir_eval.call({input_buf, weight_buf, bias_buf, result_buf}); - nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options); - ASSERT_TRUE(at::allclose(nnc_result, ref)); -} - -TEST(ExternalCall, Conv2d_int) { - // A similar test, but now using kInt tensors - - BufHandle Input("Input", {1, 3, 224, 224}, kInt); - BufHandle Weight("Weight", {16, 3, 3, 3}, kInt); - BufHandle Bias("Bias", {16}, kInt); - BufHandle ResultBuf("Result", {1, 16, 112, 112}, kInt); - int64_t stride = 2; - int64_t pad = 1; - int64_t dilation = 1; - int64_t groups = 1; - - Tensor Result = Tensor( - ResultBuf.node(), - ExternalCall::make( - ResultBuf, - "nnc_aten_conv2d", - {Input, Weight, Bias}, - {stride, stride, pad, pad, dilation, dilation, groups})); - LoopNest l({Result}); - l.prepareForCodegen(); - l.simplify(); - - auto options = at::TensorOptions() - .dtype(at::kInt) - .layout(at::kStrided) - .device(at::kCPU) - .requires_grad(false); - at::Tensor input = at::ones({1, 3, 224, 224}, options) * 5; - at::Tensor weight = at::ones({16, 3, 3, 3}, options) * 6; - at::Tensor bias = at::ones({16}, options) * 11; - at::Tensor ref = at::conv2d( - input, - weight, - bias, - {stride, stride}, - {pad, pad}, - {dilation, dilation}, - groups); - - at::Tensor nnc_result; - std::vector input_buf(1 * 3 * 224 * 224, 5); - std::vector weight_buf(16 * 3 * 3 * 3, 6); - std::vector bias_buf(16, 11); - std::vector result_buf(1 * 16 * 112 * 112, -1); - -#ifdef TORCH_ENABLE_LLVM - LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Weight, Bias, Result}); - - llvm_codegen.call({input_buf, weight_buf, bias_buf, result_buf}); - nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options); - ASSERT_TRUE(at::allclose(nnc_result, ref)); -#endif - - SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Weight, Bias, Result}); - - ir_eval.call({input_buf, weight_buf, bias_buf, result_buf}); - nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options); - ASSERT_TRUE(at::allclose(nnc_result, ref)); -} - -TEST(ExternalCall, Conv2d_nobias_noargs) { - BufHandle Input("Input", {1, 16, 112, 112}, kFloat); - BufHandle Weight("Weight", {16, 16, 1, 1}, kFloat); - BufHandle ResultBuf("Result", {1, 16, 112, 112}, kFloat); - - Tensor Result = Tensor( - ResultBuf.node(), - ExternalCall::make(ResultBuf, "nnc_aten_conv2d", {Input, Weight}, {})); - LoopNest l({Result}); - l.prepareForCodegen(); - l.simplify(); - - auto options = at::TensorOptions() - .dtype(at::kFloat) - .layout(at::kStrided) - .device(at::kCPU) - .requires_grad(false); - at::Tensor input = at::ones({1, 16, 112, 112}, options) * 5.f; - at::Tensor weight = at::ones({16, 16, 1, 1}, options) * 6.f; - at::Tensor ref = at::conv2d(input, weight); - - at::Tensor nnc_result; - std::vector input_buf(1 * 16 * 112 * 112, 5.f); - std::vector weight_buf(16 * 16 * 1 * 1, 6.f); - std::vector result_buf(1 * 16 * 112 * 112, -1.f); - -#ifdef TORCH_ENABLE_LLVM - LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Weight, Result}); - - llvm_codegen.call({input_buf, weight_buf, result_buf}); - nnc_result = 
at::from_blob(result_buf.data(), {1, 16, 112, 112}, options); - ASSERT_TRUE(at::allclose(nnc_result, ref)); -#endif - - SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Weight, Result}); - - ir_eval.call({input_buf, weight_buf, result_buf}); - nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options); - ASSERT_TRUE(at::allclose(nnc_result, ref)); -} - -TEST(ExternalCall, Addmm_float) { - BufHandle Input("Input", {100, 300}, kFloat); - BufHandle Mat1("Mat1", {100, 200}, kFloat); - BufHandle Mat2("Mat2", {200, 300}, kFloat); - BufHandle ResultBuf("Result", {100, 300}, kFloat); - int64_t beta = 2; - int64_t alpha = 2; - - Tensor Result = Tensor( - ResultBuf.node(), - ExternalCall::make( - ResultBuf, "nnc_aten_addmm", {Input, Mat1, Mat2}, {beta, alpha})); - LoopNest l({Result}); - l.prepareForCodegen(); - l.simplify(); - - auto options = at::TensorOptions() - .dtype(at::kFloat) - .layout(at::kStrided) - .device(at::kCPU) - .requires_grad(false); - at::Tensor input = at::ones({100, 300}, options) * 5.f; - at::Tensor mat1 = at::ones({100, 200}, options) * 6.f; - at::Tensor mat2 = at::ones({200, 300}, options) * 11.f; - at::Tensor ref = at::addmm(input, mat1, mat2, beta, alpha); - - at::Tensor nnc_result; - std::vector input_buf(100 * 300, 5.f); - std::vector mat1_buf(100 * 200, 6.f); - std::vector mat2_buf(200 * 300, 11.f); - std::vector result_buf(100 * 300, -1.f); - -#ifdef TORCH_ENABLE_LLVM - LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Mat1, Mat2, Result}); - - llvm_codegen.call({input_buf, mat1_buf, mat2_buf, result_buf}); - nnc_result = at::from_blob(result_buf.data(), {100, 300}, options); - ASSERT_TRUE(at::allclose(nnc_result, ref)); -#endif - - SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Mat1, Mat2, Result}); - - ir_eval.call({input_buf, mat1_buf, mat2_buf, result_buf}); - nnc_result = at::from_blob(result_buf.data(), {100, 300}, options); - ASSERT_TRUE(at::allclose(nnc_result, ref)); -} - -TEST(ExternalCall, Embedding) { - BufHandle Weight("Weight", {256, 100}, kFloat); - BufHandle Indices("Indices", {1, 115}, kLong); - BufHandle ResultBuf("Result", {1, 115, 100}, kFloat); - int64_t padding_idx = -1; - bool scale_grad_by_freq = false; - bool sparse = false; - - Tensor Result = Tensor( - ResultBuf.node(), - ExternalCall::make( - ResultBuf, - "nnc_aten_embedding", - {Weight, Indices}, - {padding_idx, (int64_t)scale_grad_by_freq, (int64_t)sparse})); - LoopNest l({Result}); - l.prepareForCodegen(); - l.simplify(); - - auto options = at::TensorOptions() - .layout(at::kStrided) - .device(at::kCPU) - .requires_grad(false); - - at::Tensor weight = at::ones({256, 100}, options.dtype(at::kFloat)) * 5.f; - at::Tensor indices = at::ones({1, 115}, options.dtype(at::kLong)) * 6; - at::Tensor ref = - at::embedding(weight, indices, padding_idx, scale_grad_by_freq, sparse); - - at::Tensor nnc_result; - std::vector weight_buf(256 * 100, 5.f); - std::vector indices_buf(1 * 115, 6); - std::vector result_buf(1 * 115 * 100, -1.f); - -#ifdef TORCH_ENABLE_LLVM - LLVMCodeGen llvm_codegen(l.root_stmt(), {Weight, Indices, Result}); - - llvm_codegen.call({weight_buf, indices_buf, result_buf}); - nnc_result = at::from_blob( - result_buf.data(), {1, 115, 100}, options.dtype(at::kFloat)); - ASSERT_TRUE(at::allclose(nnc_result, ref)); -#endif - - SimpleIREvaluator ir_eval(l.root_stmt(), {Weight, Indices, Result}); - - ir_eval.call({weight_buf, indices_buf, result_buf}); - nnc_result = at::from_blob( - result_buf.data(), {1, 115, 100}, options.dtype(at::kFloat)); - 
ASSERT_TRUE(at::allclose(nnc_result, ref)); -} - -TEST(ExternalCall, MaxReduction) { - BufHandle Input("Input", {1, 115, 152}, kFloat); - BufHandle ResultBuf("Result", {1, 152}, kFloat); - int64_t dim = 1; - bool keep_dim = false; - - Tensor Result = Tensor( - ResultBuf.node(), - ExternalCall::make( - ResultBuf, "nnc_aten_max_red", {Input}, {dim, (int64_t)keep_dim})); - LoopNest l({Result}); - l.prepareForCodegen(); - l.simplify(); - - auto options = at::TensorOptions() - .dtype(at::kFloat) - .layout(at::kStrided) - .device(at::kCPU) - .requires_grad(false); - - at::Tensor input = at::ones({1, 115, 152}, options) * 5.f; - at::Tensor ref = std::get<0>(at::max(input, dim, keep_dim)); - - at::Tensor nnc_result; - std::vector input_buf(1 * 115 * 152, 5.f); - std::vector result_buf(1 * 152, -1.f); - -#ifdef TORCH_ENABLE_LLVM - LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Result}); - - llvm_codegen.call({input_buf, result_buf}); - nnc_result = at::from_blob(result_buf.data(), {1, 152}, options); - ASSERT_TRUE(at::allclose(nnc_result, ref)); -#endif - - SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Result}); - - ir_eval.call({input_buf, result_buf}); - nnc_result = at::from_blob(result_buf.data(), {1, 152}, options); - ASSERT_TRUE(at::allclose(nnc_result, ref)); -} - -#ifdef USE_XNNPACK - -TEST(ExternalCall, Prepacked_Linear_float) { - using namespace at::native::xnnpack; - - BufHandle Input("Input", {100, 200}, kFloat); - BufHandle ResultBuf("Result", {100, 300}, kFloat); - - // Calculate reference result using at::linear. - auto options = at::TensorOptions() - .dtype(at::kFloat) - .layout(at::kStrided) - .device(at::kCPU) - .requires_grad(false); - at::Tensor input = - at::linspace(-10.0, 10.0, 100 * 200, options).resize_({100, 200}); - at::Tensor weight = - at::linspace(-10.0, 10.0, 300 * 200, options).resize_({300, 200}); - at::Tensor bias = at::linspace(-10.0, 10.0, 300, options); - at::Tensor ref = at::linear(input, weight, bias); - - // Create prepacked xnnpack context object. 
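-  // For readers unfamiliar with prepacked ops, the pattern below is, in
-  // sketch form (typed<> signature elided, as elsewhere in this file):
-  //
-  //   auto op = c10::Dispatcher::singleton()
-  //                 .findSchemaOrThrow("prepacked::linear_clamp_prepack", "")
-  //                 .typed(...);
-  //   auto ctx = op.call(weight, bias, /*output_min=*/{}, /*output_max=*/{});
-  //
-  // The opaque context owns the repacked weight and bias; at run time only
-  // the input and ctx.get() are handed to "nnc_prepacked_linear_clamp_run",
-  // so the packing cost is paid once, outside the generated kernel.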
- auto linear_clamp_prepack_op = - c10::Dispatcher::singleton() - .findSchemaOrThrow("prepacked::linear_clamp_prepack", "") - .typed( - at::Tensor, - std::optional, - const std::optional&, - const std::optional&)>(); - auto prepacked = linear_clamp_prepack_op.call( - weight, bias, std::optional(), std::optional()); - - BufHandle DummyPrepacked("DummyPrepacked", {1}, kFloat); - Tensor Result = Tensor( - ResultBuf.node(), - ExternalCall::make( - ResultBuf, - "nnc_prepacked_linear_clamp_run", - {Input, DummyPrepacked}, - {})); - LoopNest l({Result}); - l.prepareForCodegen(); - l.simplify(); - - at::Tensor nnc_result; - std::vector input_buf( - input.data_ptr(), input.data_ptr() + 100 * 200); - std::vector result_buf(100 * 300, -1.f); - -#ifdef TORCH_ENABLE_LLVM - LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, DummyPrepacked, Result}); - - llvm_codegen.call({input_buf, prepacked.get(), result_buf}); - nnc_result = at::from_blob(result_buf.data(), {100, 300}, options); - ASSERT_TRUE(at::allclose(nnc_result, ref)); -#endif - - SimpleIREvaluator ir_eval(l.root_stmt(), {Input, DummyPrepacked, Result}); - - ir_eval.call({input_buf, prepacked.get(), result_buf}); - nnc_result = at::from_blob(result_buf.data(), {100, 300}, options); - ASSERT_TRUE(at::allclose(nnc_result, ref)); -} - -TEST(ExternalCall, Prepacked_Conv2d_float) { - using namespace at::native::xnnpack; - - BufHandle Input("Input", {1, 3, 224, 224}, kFloat); - BufHandle ResultBuf("Result", {1, 16, 112, 112}, kFloat); - int64_t stride = 2; - int64_t pad = 1; - int64_t dilation = 1; - int64_t groups = 1; - - // Calculate reference result using at::conv2d. - auto options = at::TensorOptions() - .dtype(at::kFloat) - .layout(at::kStrided) - .device(at::kCPU) - .requires_grad(false); - at::Tensor input = at::linspace(-10.0, 10.0, 1 * 3 * 224 * 224, options) - .resize_({1, 3, 224, 224}); - at::Tensor weight = - at::linspace(-10.0, 10.0, 16 * 3 * 3 * 3, options).resize_({16, 3, 3, 3}); - at::Tensor bias = at::linspace(-10.0, 10.0, 16, options); - at::Tensor ref = at::conv2d( - input, - weight, - bias, - {stride, stride}, - {pad, pad}, - {dilation, dilation}, - groups); - - // Create prepacked xnnpack context object. 
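-  // Same prepacking pattern as the linear test above, but for conv2d the
-  // strides, padding, dilation, and group count are also baked into the
-  // prepacked context, so the matching run op
-  // ("nnc_prepacked_conv2d_clamp_run") again takes only the input tensor and
-  // the opaque context.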
- auto conv2d_clamp_prepack_op = - c10::Dispatcher::singleton() - .findSchemaOrThrow("prepacked::conv2d_clamp_prepack", "") - .typed( - at::Tensor, - std::optional, - std::vector, - std::vector, - std::vector, - int64_t, - const std::optional&, - const std::optional&)>(); - auto prepacked = conv2d_clamp_prepack_op.call( - weight, - bias, - {stride, stride}, - {pad, pad}, - {dilation, dilation}, - groups, - std::optional(), - std::optional()); - - BufHandle DummyPrepacked("DummyPrepacked", {1}, kFloat); - Tensor Result = Tensor( - ResultBuf.node(), - ExternalCall::make( - ResultBuf, - "nnc_prepacked_conv2d_clamp_run", - {Input, DummyPrepacked}, - {})); - LoopNest l({Result}); - l.prepareForCodegen(); - l.simplify(); - - at::Tensor nnc_result; - std::vector input_buf( - input.data_ptr(), input.data_ptr() + 1 * 3 * 224 * 224); - std::vector result_buf(1 * 16 * 112 * 112, -1.f); - -#ifdef TORCH_ENABLE_LLVM - LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, DummyPrepacked, Result}); - - llvm_codegen.call({input_buf, prepacked.get(), result_buf}); - nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options); - ASSERT_TRUE(at::allclose(nnc_result, ref, 1e-03, 1e-03)); -#endif - - SimpleIREvaluator ir_eval(l.root_stmt(), {Input, DummyPrepacked, Result}); - - ir_eval.call({input_buf, prepacked.get(), result_buf}); - nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options); - ASSERT_TRUE(at::allclose(nnc_result, ref, 1e-03, 1e-03)); -} - -#endif // USE_XNNPACK - -TEST(ExternalCall, BinaryFloat) { - using TensorFunc = std::function; - using Test = std::tuple< - std::vector, - std::vector, - std::vector, - TensorFunc, - std::string>; - std::vector tests = {}; - tests.push_back( - Test{{100, 200}, {200, 300}, {100, 300}, at::matmul, "nnc_aten_matmul"}); - tests.push_back(Test{{100, 300}, {300}, {100}, at::mv, "nnc_aten_mv"}); - tests.push_back(Test{ - {100, 200}, - {200, 300}, - {100, 300}, - [&](const at::Tensor& a, const at::Tensor& b) { return at::mm(a, b); }, - "nnc_aten_mm"}); - for (auto curTest : tests) { - auto [aShape, bShape, resShape, torchFunc, externCallName] = curTest; - auto toExprHandleVec = [](std::vector v) { - auto intV = std::vector(v.begin(), v.end()); - return std::vector(intV.begin(), intV.end()); - }; - BufHandle A("A", toExprHandleVec(aShape), kFloat); - BufHandle B("B", toExprHandleVec(bShape), kFloat); - BufHandle ResultBuf("Result", toExprHandleVec(resShape), kFloat); - - Tensor Result = Tensor( - ResultBuf.node(), - ExternalCall::make(ResultBuf, externCallName, {A, B}, {})); - LoopNest l({Result}); - l.prepareForCodegen(); - l.simplify(); - - auto options = at::TensorOptions() - .dtype(at::kFloat) - .layout(at::kStrided) - .device(at::kCPU) - .requires_grad(false); - at::Tensor a = at::ones(c10::IntArrayRef(aShape), options) * 5.f; - at::Tensor b = at::ones(c10::IntArrayRef(bShape), options) * 6.f; - at::Tensor ref = torchFunc(a, b); - - auto prod = [](std::vector v) { - // NOLINTNEXTLINE(modernize-use-transparent-functors) - return std::accumulate(v.begin(), v.end(), 1, std::multiplies()); - }; - - at::Tensor nnc_result; - std::vector a_buf(prod(aShape), 5.f); - std::vector b_buf(prod(bShape), 6.f); - std::vector result_buf(prod(resShape), -1.f); - -#ifdef TORCH_ENABLE_LLVM - LLVMCodeGen llvm_codegen(l.root_stmt(), {A, B, Result}); - - llvm_codegen.call({a_buf, b_buf, result_buf}); - nnc_result = - at::from_blob(result_buf.data(), c10::IntArrayRef(resShape), options); - ASSERT_TRUE(at::allclose(nnc_result, ref)); -#endif - - SimpleIREvaluator 
ir_eval(l.root_stmt(), {A, B, Result}); - ir_eval.call({a_buf, b_buf, result_buf}); - nnc_result = - at::from_blob(result_buf.data(), c10::IntArrayRef(resShape), options); - ASSERT_TRUE(at::allclose(nnc_result, ref)); - } -} - -TEST(ExternalCall, UnaryFloat) { - using TensorFunc = std::function; - auto toExprHandleVec = [](std::vector v) { - auto intV = std::vector(v.begin(), v.end()); - return std::vector(intV.begin(), intV.end()); - }; - using Test = std::tuple< - std::vector, - std::vector, - TensorFunc, - std::string, - std::vector>; - std::vector tests = {}; - tests.push_back(Test{ - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - {1, 64, 8, 9}, - {1, 64, 5, 7}, - [](at::Tensor x) { return at::adaptive_avg_pool2d(x, {5, 7}); }, - "nnc_aten_adaptive_avg_pool2d", - toExprHandleVec({5, 7})}); - tests.push_back(Test{// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - {100, 200}, - {100}, - [](at::Tensor x) { return at::mean(x, {1}); }, - "nnc_aten_mean", - toExprHandleVec({1, /*keepdim=*/0})}); - for (auto curTest : tests) { - auto [aShape, resShape, torchFunc, externCallName, externCallArgs] = - curTest; - BufHandle A("A", toExprHandleVec(aShape), kFloat); - BufHandle ResultBuf("Result", toExprHandleVec(resShape), kFloat); - - Tensor Result = Tensor( - ResultBuf.node(), - ExternalCall::make(ResultBuf, externCallName, {A}, externCallArgs)); - LoopNest l({Result}); - l.prepareForCodegen(); - l.simplify(); - - auto options = at::TensorOptions() - .dtype(at::kFloat) - .layout(at::kStrided) - .device(at::kCPU) - .requires_grad(false); - at::Tensor a = at::ones(c10::IntArrayRef(aShape), options) * 5.f; - at::Tensor ref = torchFunc(a); - - auto prod = [](std::vector v) { - // NOLINTNEXTLINE(modernize-use-transparent-functors) - return std::accumulate(v.begin(), v.end(), 1, std::multiplies()); - }; - - at::Tensor nnc_result; - std::vector a_buf(prod(aShape), 5.f); - std::vector result_buf(prod(resShape), -1.f); - -#ifdef TORCH_ENABLE_LLVM - LLVMCodeGen llvm_codegen(l.root_stmt(), {A, Result}); - - llvm_codegen.call({a_buf, result_buf}); - nnc_result = - at::from_blob(result_buf.data(), c10::IntArrayRef(resShape), options); - ASSERT_TRUE(at::allclose(nnc_result, ref)); -#endif - - SimpleIREvaluator ir_eval(l.root_stmt(), {A, Result}); - ir_eval.call({a_buf, result_buf}); - nnc_result = - at::from_blob(result_buf.data(), c10::IntArrayRef(resShape), options); - ASSERT_TRUE(at::allclose(nnc_result, ref)); - } -} - -TEST(ExternalCall, ComputeInterop) { - // This test verifies that Tensors using external calls can be used by and can - // use Tensors built with Compute API. 
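-  // The interop goes in both directions (sketch, using the names defined
-  // below): an ExternalCall consumes Compute-built tensors through their
-  // Bufs, e.g.
-  //
-  //   ExternalCall::make(ConvResultBuf, "nnc_aten_conv2d",
-  //                      {BufHandle(Input.buf()), BufHandle(Weight.buf())}, {});
-  //
-  // and a later Compute lambda reads the ExternalCall's result with
-  // ConvResult.load(n, c, h, w).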
- - BufHandle ConvResultBuf("ConvResult", {1, 16, 32, 32}, kFloat); - BufHandle MatmulResultBuf("MatmulResult", {1, 16, 32, 32}, kFloat); - - Tensor Input = Compute( - "Input", - {1, 16, 32, 32}, - [&](const VarHandle& n, - const VarHandle& c, - const VarHandle& h, - const VarHandle& w) { return FloatImm::make(5.0f); }); - Tensor Weight = Compute( - "Weight", - {16, 16, 1, 1}, - [&](const VarHandle& n, - const VarHandle& c, - const VarHandle& h, - const VarHandle& w) { return FloatImm::make(6.0f); }); - - Tensor ConvResult = Tensor( - ConvResultBuf.node(), - ExternalCall::make( - ConvResultBuf, - "nnc_aten_conv2d", - {BufHandle(Input.buf()), BufHandle(Weight.buf())}, - {})); - Tensor MatmulResult = Tensor( - MatmulResultBuf.node(), - ExternalCall::make( - MatmulResultBuf, - "nnc_aten_matmul", - {BufHandle(ConvResult.buf()), BufHandle(ConvResult.buf())}, - {})); - Tensor Result = Compute( - "Result", - {1, 16, 32, 32}, - [&](const VarHandle& n, - const VarHandle& c, - const VarHandle& h, - const VarHandle& w) { - return ConvResult.load(n, c, h, w) + MatmulResult.load(n, c, h, w); - }); - - LoopNest l({Input, Weight, ConvResult, MatmulResult, Result}); - - // Inlining should not inline anything here since all Bufs are either defined - // or used in ExternalCalls - we run it just for testing - l.inlineIntermediateBufs(true); - - l.prepareForCodegen(); - l.simplify(); - - auto options = at::TensorOptions() - .dtype(at::kFloat) - .layout(at::kStrided) - .device(at::kCPU) - .requires_grad(false); - at::Tensor input = at::ones({1, 16, 32, 32}, options) * 5.f; - at::Tensor weight = at::ones({16, 16, 1, 1}, options) * 6.f; - at::Tensor t = at::conv2d(input, weight); - at::Tensor t2 = at::matmul(t, t); - at::Tensor ref = t + t2; - - at::Tensor nnc_result; - std::vector input_buf(1 * 16 * 32 * 32, 5.f); - std::vector weight_buf(16 * 16 * 1 * 1, 6.f); - std::vector conv_result_buf(1 * 16 * 32 * 32, -1.f); - std::vector matmul_result_buf(1 * 16 * 32 * 32, -1.f); - std::vector result_buf(1 * 16 * 32 * 32, -1.f); - -#ifdef TORCH_ENABLE_LLVM - LLVMCodeGen llvm_codegen( - l.root_stmt(), {Input, Weight, ConvResult, MatmulResult, Result}); - - llvm_codegen.call( - {input_buf, weight_buf, conv_result_buf, matmul_result_buf, result_buf}); - nnc_result = at::from_blob(result_buf.data(), {1, 16, 32, 32}, options); - ASSERT_TRUE(at::allclose(nnc_result, ref)); -#endif - - SimpleIREvaluator ir_eval( - l.root_stmt(), {Input, Weight, ConvResult, MatmulResult, Result}); - - ir_eval.call( - {input_buf, weight_buf, conv_result_buf, matmul_result_buf, result_buf}); - nnc_result = at::from_blob(result_buf.data(), {1, 16, 32, 32}, options); - ASSERT_TRUE(at::allclose(nnc_result, ref)); -} - -TEST(ExternalCall, Inlining) { - // This test verifies that Tensors using external calls can be used by and - // can use Tensors built with Compute API. 
-
-  BufHandle MatmulResultBuf("MatmulResult", {8, 8}, kFloat);
-
-  Tensor A = Compute("A", {8, 8}, [&](const VarHandle& i, const VarHandle& j) {
-    return FloatImm::make(5.0f);
-  });
-  Tensor B = Compute("B", {8, 8}, [&](const VarHandle& i, const VarHandle& j) {
-    return FloatImm::make(4.0f);
-  });
-  Tensor MatmulResult = Tensor(
-      MatmulResultBuf.node(),
-      ExternalCall::make(
-          MatmulResultBuf,
-          "nnc_aten_matmul",
-          {BufHandle(A.buf()), BufHandle(B.buf())},
-          {}));
-  Tensor Result =
-      Compute("Result", {8, 8}, [&](const VarHandle& i, const VarHandle& j) {
-        return MatmulResult.load(i, j) + FloatImm::make(3.0f);
-      });
-
-  StmtPtr root_stmt = alloc(std::vector(
-      {A.stmt(), B.stmt(), MatmulResult.stmt(), Result.stmt()}));
-  LoopNest l(root_stmt, {Result.buf()});
-
-  // Inlining should not inline anything here since all Bufs are either
-  // defined or used in ExternalCalls
-  l.inlineIntermediateBufs(false);
-
-  l.prepareForCodegen();
-  l.simplify();
-
-  auto options = at::TensorOptions()
-                     .dtype(at::kFloat)
-                     .layout(at::kStrided)
-                     .device(at::kCPU)
-                     .requires_grad(false);
-  at::Tensor a = at::ones({8, 8}, options) * 5.f;
-  at::Tensor b = at::ones({8, 8}, options) * 4.f;
-  at::Tensor t = at::matmul(a, b);
-  at::Tensor ref = t + 3.f;
-
-  at::Tensor nnc_result;
-  std::vector result_buf(8 * 8);
-
-#ifdef TORCH_ENABLE_LLVM
-  LLVMCodeGen llvm_codegen(l.root_stmt(), {Result});
-
-  llvm_codegen.call({result_buf});
-  nnc_result = at::from_blob(result_buf.data(), {8, 8}, options);
-  ASSERT_TRUE(at::allclose(nnc_result, ref));
-#endif
-
-  SimpleIREvaluator ir_eval(l.root_stmt(), {Result});
-
-  ir_eval.call({result_buf});
-  nnc_result = at::from_blob(result_buf.data(), {8, 8}, options);
-  ASSERT_TRUE(at::allclose(nnc_result, ref));
-}
-
-TEST(ExternalCall, JitCustomFusionOp) {
-  const char* custom_op_schema_literal =
-      "nnc_custom::add_mul(Tensor a, Tensor b, Tensor c) -> Tensor";
-  const char* external_func_name = "nnc_add_mul";
-
-  auto add_mul_lowering_func =
-      [external_func_name](
-          const std::vector& inputs,
-          const std::vector& output_shape,
-          const std::vector& output_strides,
-          const std::optional& output_type,
-          at::Device device) {
-        auto output_dtype = Dtype(*output_type);
-        torch::jit::tensorexpr::BufHandle result_buf(
-            "nnc_add_mul_res_buf", output_shape, output_dtype);
-        const torch::jit::tensorexpr::BufHandle& a =
-            std::get(inputs[0]);
-        const torch::jit::tensorexpr::BufHandle& b =
-            std::get(inputs[1]);
-        const torch::jit::tensorexpr::BufHandle& c =
-            std::get(inputs[2]);
-        torch::jit::tensorexpr::StmtPtr s =
-            torch::jit::tensorexpr::ExternalCall::make(
-                result_buf, external_func_name, {a, b, c}, {});
-        return Tensor(result_buf.node(), s);
-      };
-
-  auto add_mul_external_func = [](int64_t bufs_num,
-                                  void** buf_data,
-                                  int64_t* buf_ranks,
-                                  int64_t* buf_dims,
-                                  int64_t* buf_strides,
-                                  int8_t* buf_dtypes,
-                                  int64_t args_num,
-                                  int64_t* extra_args) {};
-
-  torch::jit::RegisterOperators reg({Operator(
-      custom_op_schema_literal,
-      [](const Node* node) -> Operation {
-        return [](Stack& _stack) {
-          auto a = std::move(peek(_stack, 0, 3)).toTensor();
-          auto b = std::move(peek(_stack, 1, 3)).toTensor();
-          auto c = std::move(peek(_stack, 2, 3)).toTensor();
-          drop(_stack, 3);
-          auto result = (a + b) * c;
-          pack(_stack, std::move(result));
-          return 0;
-        };
-      },
-      c10::AliasAnalysisKind::FROM_SCHEMA)});
-
-  auto& custom_operator_set = torch::jit::tensorexpr::getCustomOperatorSet();
-  custom_operator_set.insert({custom_op_schema_literal});
-
-  auto& te_lowering_registry =
-      torch::jit::tensorexpr::getNNCLoweringRegistry();
-  te_lowering_registry.insert(
-      parseSchema(custom_op_schema_literal), add_mul_lowering_func);
-
-  auto& te_nnc_func_registry = torch::jit::tensorexpr::getNNCFunctionRegistry();
-  te_nnc_func_registry[external_func_name] = add_mul_external_func;
-
-  std::string graph_string = R"IR(
-    graph(%a : Float(10, 20, strides=[20, 1], device=cpu),
-          %b : Float(10, 20, strides=[20, 1], device=cpu),
-          %c : Float(10, 20, strides=[20, 1], device=cpu)):
-      %res : Float(10, 20, strides=[20, 1], device=cpu) = nnc_custom::add_mul(%a, %b, %c)
-      return (%res))IR";
-
-  auto graph = std::make_shared();
-  torch::jit::parseIR(graph_string, graph.get());
-
-  std::string shape_compute_python_string = R"PY(
-  def computeOutput(a: List[int], b: List[int], c: List[int]):
-    expandedSizes: List[int] = []
-    dimsA = len(a)
-    dimsB = len(b)
-    dimsC = len(c)
-    ndim = max(dimsA, dimsB, dimsC)
-    for i in range(ndim):
-      offset = ndim - 1 - i
-      dimA = dimsA - 1 - offset
-      dimB = dimsB - 1 - offset
-      dimC = dimsC - 1 - offset
-      sizeA = a[dimA] if (dimA >= 0) else 1
-      sizeB = b[dimB] if (dimB >= 0) else 1
-      sizeC = c[dimC] if (dimC >= 0) else 1
-
-      if sizeA != sizeB and sizeB != sizeC and sizeA != 1 and sizeB != 1 and sizeC != 1:
-        # TODO: only assertion error is bound in C++ compilation right now
-        raise AssertionError(
-          "The size of tensor a {} must match the size of tensor b ("
-          "{} and c {}) at non-singleton dimension {}".format(sizeA, sizeB, sizeC, i)
-        )
-
-      expandedSizes.append(max(sizeA, sizeB, sizeC))
-
-    return expandedSizes
-  )PY";
-  auto cu_ptr = torch::jit::compile(shape_compute_python_string);
-  torch::jit::GraphFunction* gf =
-      (torch::jit::GraphFunction*)&cu_ptr->get_function("computeOutput");
-  ASSERT_TRUE(gf);
-
-#ifdef TORCH_ENABLE_LLVM
-  auto static_graph_case = graph->copy();
-  FuseTensorExprs(static_graph_case, 1);
-  torch::jit::testing::FileCheck()
-      .check("prim::TensorExprGroup_")
-      ->check("nnc_custom::add_mul")
-      ->run(*static_graph_case);
-
-  auto dynamic_graph_case = graph->copy();
-  auto custom_op = torch::jit::getOperatorForLiteral(custom_op_schema_literal);
-  ASSERT_TRUE(custom_op);
-  torch::jit::RegisterShapeComputeGraphForSchema(
-      custom_op->schema(), gf->graph());
-  FuseTensorExprs(dynamic_graph_case, 1, false, true);
-  torch::jit::testing::FileCheck()
-      .check("prim::TensorExprGroup_")
-      ->check("nnc_custom::add_mul")
-      ->run(*dynamic_graph_case);
-#else
-  torch::jit::testing::FileCheck().check("nnc_custom::add_mul")->run(*graph);
-#endif
-}
-
-} // namespace jit
-} // namespace torch
diff --git a/test/cpp/tensorexpr/test_graph_opt.cpp b/test/cpp/tensorexpr/test_graph_opt.cpp
deleted file mode 100644
index aed73d09d14d..000000000000
--- a/test/cpp/tensorexpr/test_graph_opt.cpp
+++ /dev/null
@@ -1,319 +0,0 @@
-#include
-
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-
-#include
-
-namespace torch {
-namespace jit {
-
-using namespace torch::jit::tensorexpr;
-
-class GraphOpt : public ::testing::Test {
- public:
-  void SetUp() override {
-    old_cat_wo_conditionals_ = getCatWoConditionals();
-    getCatWoConditionals() = true;
-  }
-
-  void TearDown() override {
-    getCatWoConditionals() = old_cat_wo_conditionals_;
-  }
-
- private:
-  bool old_cat_wo_conditionals_;
-};
-
-TEST_F(GraphOpt, OptimizeCat) {
-#ifdef TORCH_ENABLE_LLVM
-  const auto graph_string = R"IR(
-    graph(%x : Float(10, strides=[1], device=cpu),
-          %y : Float(20, strides=[1], device=cpu),
-          %z : Float(30, strides=[1], device=cpu)):
-      %dim : int =
prim::Constant[value=0]() - %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z) - %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim) - %5 : Float(60, strides=[1], device=cpu) = aten::log(%cat) - return (%5))IR"; - auto g = std::make_shared(); - torch::jit::parseIR(graph_string, g.get()); - g->lint(); - - TensorExprKernel kernel(g); - - // The `aten::log` op must be moved to the inputs of `aten::cat`. - testing::FileCheck() - .check("aten::log") - ->check("aten::log") - ->check("aten::log") - ->check("aten::cat") - ->check_not("aten::log") - ->run(*kernel.graph()); - - auto x = at::rand({10}, at::kFloat); - auto y = at::rand({20}, at::kFloat); - auto z = at::rand({30}, at::kFloat); - auto ref = at::log(at::cat({x, y, z}, 0)); - - std::vector inputs = {x, y, z}; - std::vector stack = fmap(inputs); - kernel.run(stack); - auto out = stack[0].toTensor(); - ASSERT_EQ(out.sizes(), ref.sizes()); - ASSERT_EQ(out.dtype(), ref.dtype()); - ASSERT_TRUE(at::allclose(out, ref)); -#endif -} - -TEST_F(GraphOpt, OptimizeCat2) { -#ifdef TORCH_ENABLE_LLVM - const auto graph_string = R"IR( - graph(%x : Float(10, strides=[1], device=cpu), - %y : Float(20, strides=[1], device=cpu), - %z : Float(30, strides=[1], device=cpu)): - %dim : int = prim::Constant[value=0]() - %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z) - %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim) - %5 : Float(60, strides=[1], device=cpu) = aten::log(%cat) - %6 : Float(60, strides=[1], device=cpu) = aten::tanh(%5) - return (%6))IR"; - auto g = std::make_shared(); - torch::jit::parseIR(graph_string, g.get()); - g->lint(); - - TensorExprKernel kernel(g); - - // The `aten::log` and `aten::tanh` ops must be moved to the inputs of - // `aten::cat`. - testing::FileCheck() - .check("aten::log") - ->check("aten::log") - ->check("aten::log") - ->check("aten::tanh") - ->check("aten::tanh") - ->check("aten::tanh") - ->check("aten::cat") - ->check_not("aten::log") - ->check_not("aten::tanh") - ->run(*kernel.graph()); - - auto x = at::rand({10}, at::kFloat); - auto y = at::rand({20}, at::kFloat); - auto z = at::rand({30}, at::kFloat); - auto ref = at::tanh(at::log(at::cat({x, y, z}, 0))); - - std::vector inputs = {x, y, z}; - std::vector stack = fmap(inputs); - kernel.run(stack); - auto out = stack[0].toTensor(); - ASSERT_EQ(out.sizes(), ref.sizes()); - ASSERT_EQ(out.dtype(), ref.dtype()); - ASSERT_TRUE(at::allclose(out, ref)); -#endif -} - -TEST_F(GraphOpt, OptimizeCat3) { -#ifdef TORCH_ENABLE_LLVM - const auto graph_string = R"IR( - graph(%a : Float(60, strides=[1], device=cpu), - %x : Float(10, strides=[1], device=cpu), - %y : Float(20, strides=[1], device=cpu), - %z : Float(30, strides=[1], device=cpu)): - %dim : int = prim::Constant[value=0]() - %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z) - %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim) - %5 : Float(60, strides=[1], device=cpu) = aten::tanh(%cat) - %6 : Float(60, strides=[1], device=cpu) = aten::mul(%a, %5) - return (%6))IR"; - auto g = std::make_shared(); - torch::jit::parseIR(graph_string, g.get()); - g->lint(); - - TensorExprKernel kernel(g); - - // The `aten::tanh` op must be moved to the inputs of `aten::cat`. - // But the `aten::mul` op must not be moved since it is not a single-tensor - // op (it has 2 tensor inputs). 
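-  // Roughly, the rewrite being verified is (illustrative IR, not exact
-  // printer output):
-  //
-  //   before: %5 = aten::tanh(aten::cat([%x, %y, %z]))
-  //           %6 = aten::mul(%a, %5)
-  //   after:  %5 = aten::cat([aten::tanh(%x), aten::tanh(%y), aten::tanh(%z)])
-  //           %6 = aten::mul(%a, %5)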
- testing::FileCheck() - .check("aten::tanh") - ->check("aten::tanh") - ->check("aten::tanh") - ->check("aten::cat") - ->check("aten::mul") - ->check_not("aten::tanh") - ->run(*kernel.graph()); - - auto a = at::rand({60}, at::kFloat); - auto x = at::rand({10}, at::kFloat); - auto y = at::rand({20}, at::kFloat); - auto z = at::rand({30}, at::kFloat); - auto ref = at::tanh(at::cat({x, y, z}, 0)) * a; - - std::vector inputs = {a, x, y, z}; - std::vector stack = fmap(inputs); - kernel.run(stack); - auto out = stack[0].toTensor(); - ASSERT_EQ(out.sizes(), ref.sizes()); - ASSERT_EQ(out.dtype(), ref.dtype()); - ASSERT_TRUE(at::allclose(out, ref)); -#endif -} - -TEST_F(GraphOpt, OptimizeCatWithTypePromotionInUser) { -#ifdef TORCH_ENABLE_LLVM - const auto graph_string = R"IR( - graph(%x : Int(10, strides=[1], device=cpu), - %y : Int(20, strides=[1], device=cpu), - %z : Int(30, strides=[1], device=cpu)): - %dim : int = prim::Constant[value=0]() - %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z) - %cat : Int(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim) - %5 : Float(60, strides=[1], device=cpu) = aten::tanh(%cat) - return (%5))IR"; - auto g = std::make_shared(); - torch::jit::parseIR(graph_string, g.get()); - g->lint(); - - TensorExprKernel kernel(g); - - // The `aten::tanh` op must be moved to the inputs of `aten::cat`. - // The scalar type of the inputs to `cat` should now be `Float` since they - // are the result of `tanh` which does the type promotion. - testing::FileCheck() - .check("aten::tanh") - ->check("aten::tanh") - ->check("aten::tanh") - ->check("aten::cat") - ->check_not("aten::tanh") - ->run(*kernel.graph()); - - auto x = at::randint(std::numeric_limits::max(), {10}, at::kInt); - auto y = at::randint(std::numeric_limits::max(), {20}, at::kInt); - auto z = at::randint(std::numeric_limits::max(), {30}, at::kInt); - auto ref = at::tanh(at::cat({x, y, z}, 0)); - - std::vector inputs = {x, y, z}; - std::vector stack = fmap(inputs); - kernel.run(stack); - auto out = stack[0].toTensor(); - ASSERT_EQ(out.sizes(), ref.sizes()); - ASSERT_EQ(out.dtype(), ref.dtype()); - ASSERT_TRUE(at::allclose(out, ref)); -#endif -} - -TEST_F(GraphOpt, OptimizeCatWithTypePromotionInCat) { -#ifdef TORCH_ENABLE_LLVM - const auto graph_string = R"IR( - graph(%x : Float(10, strides=[1], device=cpu), - %y : Float(20, strides=[1], device=cpu), - %z : Double(30, strides=[1], device=cpu)): - %dim : int = prim::Constant[value=0]() - %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z) - %cat : Double(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim) - %5 : Double(60, strides=[1], device=cpu) = aten::log(%cat) - return (%5))IR"; - auto g = std::make_shared(); - torch::jit::parseIR(graph_string, g.get()); - g->lint(); - - TensorExprKernel kernel(g); - - // No transformation should have happened because the `aten::cat` op performs - // type promotion. This case is currently not handled. 
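-  // Hoisting `aten::log` above this `aten::cat` would change where the
-  // Float -> Double promotion happens: %x and %y would be log'd at float
-  // precision instead of double, so the results would not match. The
-  // optimization therefore bails out on promoting cats.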
- testing::FileCheck() - .check("aten::cat") - ->check("aten::log") - ->check_not("aten::cat") - ->check_not("aten::log") - ->run(*kernel.graph()); -#endif -} - -TEST_F(GraphOpt, OptimizeCatNoSingleTensorElementwiseOp) { -#ifdef TORCH_ENABLE_LLVM - const auto graph_string = R"IR( - graph(%0 : Float(60, strides=[1], device=cpu), - %x : Float(10, strides=[1], device=cpu), - %y : Float(20, strides=[1], device=cpu), - %z : Float(30, strides=[1], device=cpu)): - %dim : int = prim::Constant[value=0]() - %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z) - %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim) - %5 : Float(60, strides=[1], device=cpu) = aten::mul(%0, %cat) - return (%5))IR"; - auto g = std::make_shared(); - torch::jit::parseIR(graph_string, g.get()); - g->lint(); - - TensorExprKernel kernel(g); - - // No transformation is expected since the consumers of cat are not - // single-tensor element-wise ops. - testing::FileCheck() - .check("aten::cat") - ->check("aten::mul") - ->check_not("aten::cat") - ->check_not("aten::mul") - ->run(*kernel.graph()); -#endif -} - -TEST_F(GraphOpt, OptimizeCatNoSingleTensorElementwiseOp2) { -#ifdef TORCH_ENABLE_LLVM - const auto graph_string = R"IR( - graph(%0 : Float(60, strides=[1], device=cpu), - %1 : Float(60, strides=[1], device=cpu), - %x : Float(10, strides=[1], device=cpu), - %y : Float(20, strides=[1], device=cpu), - %z : Float(30, strides=[1], device=cpu)): - %one : int = prim::Constant[value=1]() - %dim : int = prim::Constant[value=0]() - %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z) - %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim) - %5 : Float(60, strides=[1], device=cpu) = aten::mul(%0, %cat) - %6 : Float(60, strides=[1], device=cpu) = aten::add(%5, %1, %one) - return (%6))IR"; - auto g = std::make_shared(); - torch::jit::parseIR(graph_string, g.get()); - g->lint(); - - TensorExprKernel kernel(g); - - // No transformation is expected since the consumers of cat are not - // single-tensor element-wise ops. 
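-  // Moving a binary op such as `aten::mul` past the cat would require
-  // slicing its other operand (%0 here, and %1 for the `aten::add`) to match
-  // each cat input, which this optimization does not attempt.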
- testing::FileCheck() - .check("aten::cat") - ->check("aten::mul") - ->check("aten::add") - ->check_not("aten::cat") - ->check_not("aten::mul") - ->check_not("aten::add") - ->run(*kernel.graph()); -#endif -} - -TEST_F(GraphOpt, AOTGraphPrepPasses) { - const auto graph_string = R"IR( - graph(%x, %y, %z, %i : int): - %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z) - return (%xyz_list, %i))IR"; - auto g = std::make_shared(); - torch::jit::parseIR(graph_string, g.get()); - - removeGraphOutput(g, 1); - replaceListOutputWithTuple(g); - LowerAllTuples(g); - - testing::FileCheck().check("return (%x, %y, %z)")->run(*g); -} - -} // namespace jit -} // namespace torch diff --git a/test/cpp/tensorexpr/test_ir_printer.cpp b/test/cpp/tensorexpr/test_ir_printer.cpp deleted file mode 100644 index 4d2f8c6e906e..000000000000 --- a/test/cpp/tensorexpr/test_ir_printer.cpp +++ /dev/null @@ -1,98 +0,0 @@ -#include - -#include -#include "test/cpp/tensorexpr/test_base.h" - -#include -#include -#include -#include -#include -#include - -#include -namespace torch { -namespace jit { - -using namespace torch::jit::tensorexpr; - -TEST(IRPrinter, BasicValueTest) { - ExprHandle a = IntImm::make(2), b = IntImm::make(3); - ExprHandle c = Add::make(a, b); - - std::stringstream ss; - ss << c; - ASSERT_EQ(ss.str(), "2 + 3"); -} - -TEST(IRPrinter, BasicValueTest02) { - ExprHandle a(2.0f); - ExprHandle b(3.0f); - ExprHandle c(4.0f); - ExprHandle d(5.0f); - ExprHandle f = (a + b) - (c + d); - - std::stringstream ss; - ss << f; - ASSERT_EQ(ss.str(), "(2.f + 3.f) - (4.f + 5.f)"); -} - -TEST(IRPrinter, BasicValueTest03) { - ExprHandle a(3.402823466385289e+38f); - ExprHandle b(-3.402823466385289e+38f); - std::stringstream ss; - ss << a << ", " << b; - ASSERT_EQ(ss.str(), "3.402823466385289e+38f, -3.402823466385289e+38f"); -} - -TEST(IRPrinter, CastTest) { - VarHandle x("x", kHalf); - VarHandle y("y", kFloat); - ExprHandle body = ExprHandle(2.f) + - (Cast::make(kFloat, x) * ExprHandle(3.f) + ExprHandle(4.f) * y); - - std::stringstream ss; - ss << body; - ASSERT_EQ(ss.str(), "2.f + (float(x) * 3.f + 4.f * y)"); -} - -TEST(IRPrinter, FunctionName) { - int M = 4; - int N = 20; - - Tensor producer = Compute( - "producer", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { - return m * n; - }); - - Tensor chunk_0 = Compute( - "chunk_0", {M, N / 2}, [&](const ExprHandle& m, const ExprHandle& n) { - return producer.load(m, n); - }); - - Tensor chunk_1 = Compute( - "chunk_1", {M, N / 2}, [&](const ExprHandle& m, const ExprHandle& n) { - return producer.load(m, n + ExprHandle(N / 2)); - }); - - Tensor consumer = Compute( - "consumer", {M, N / 2}, [&](const ExprHandle& i, const ExprHandle& j) { - return i * chunk_1.load(i, j); - }); - - LoopNest l({chunk_0, chunk_1, consumer}); - auto body = LoopNest::sanitizeNames(l.root_stmt()); - - std::stringstream ss; - ss << *body; - - const std::string& verification_pattern = - R"IR( - # CHECK: for (int i_2 - # CHECK: for (int j_2 - # CHECK: consumer[i_2, j_2] = i_2 * (chunk_1[i_2, j_2])IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, ss.str()); -} -} // namespace jit -} // namespace torch diff --git a/test/cpp/tensorexpr/test_ir_verifier.cpp b/test/cpp/tensorexpr/test_ir_verifier.cpp deleted file mode 100644 index 886213ea9c76..000000000000 --- a/test/cpp/tensorexpr/test_ir_verifier.cpp +++ /dev/null @@ -1,191 +0,0 @@ -#include - -#include -#include "test/cpp/tensorexpr/test_base.h" - -#include -#include -#include -#include -#include -#include - -#include -namespace 
torch { -namespace jit { - -using namespace torch::jit::tensorexpr; - -TEST(IRVerifier, BitwiseOps) { - VarPtr X = alloc("x", kInt); - VarPtr Y = alloc("y", kFloat); - { - auto a = alloc(X, Y); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) - EXPECT_ANY_THROW(verify(a)); - } - { - auto a = alloc(X, Y); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) - EXPECT_ANY_THROW(verify(a)); - } - { - auto a = alloc(X, Y); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) - EXPECT_ANY_THROW(verify(a)); - } - { - auto a = alloc(X, Y); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) - EXPECT_ANY_THROW(verify(a)); - } - { - auto a = alloc(X, Y); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) - EXPECT_ANY_THROW(verify(a)); - } -} - -TEST(IRVerifier, CompareSelect) { - ExprPtr X = alloc(1); - ExprPtr Y = alloc(3.14f); - { - auto a = alloc(X, X, X, Y, kEQ); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) - EXPECT_ANY_THROW(verify(a)); - } - { - auto a = alloc(X, Y, X, X, kEQ); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) - EXPECT_ANY_THROW(verify(a)); - } -} - -TEST(IRVerifier, Ramp) { - VarPtr I = alloc("i", kInt); - VarPtr J = alloc("j", kFloat); - { - auto a = alloc(I, J, 4); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) - EXPECT_ANY_THROW(verify(a)); - } -} - -TEST(IRVerifier, Load) { - VarPtr I = alloc("i", kInt); - VarPtr J = alloc("j", kLong); - VarPtr K = alloc("k", kFloat); - BufPtr B = alloc( - "b", - std::vector({alloc(10), alloc(20)}), - kFloat); - { - // Indices with different int dtypes (kInt, kLong) are ok - auto a = alloc(B, std::vector({I, J})); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) - EXPECT_NO_THROW(verify(a)); - } - { - // Float index - auto a = alloc(B, std::vector({K, K})); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) - EXPECT_ANY_THROW(verify(a)); - } - { - // Multilanes are only allowed in flattened indices - auto multilane_index = alloc(I, alloc(1), 4); - auto a = alloc(B, std::vector({I, multilane_index})); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) - EXPECT_ANY_THROW(verify(a)); - } -} - -TEST(IRVerifier, IfThenElse) { - VarPtr I = alloc("i", kInt); - VarPtr J = alloc("j", kLong); - VarPtr K = alloc("k", kFloat); - { - // Condition must be integral - auto a = alloc(K, I, I); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) - EXPECT_ANY_THROW(verify(a)); - } - { - // Dtypes of true and false exprs must match - auto a = alloc(I, I, J); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) - EXPECT_ANY_THROW(verify(a)); - } - { - // Can't have multiple lanes in condition expr - auto a = alloc(alloc(I, 4), I, I); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) - EXPECT_ANY_THROW(verify(a)); - } -} - -TEST(IRVerifier, For) { - VarPtr I = alloc("i", kInt); - 
VarPtr J = alloc("j", kInt); - StmtPtr body = alloc(std::vector({})); - { - // Can't have nullptr as a Var - auto a = alloc(nullptr, I, J, body); - // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) - EXPECT_ANY_THROW(verify(a)); - } -} - -TEST(IRVerifier, Block) { - VarPtr I = alloc("i", kInt); - BufPtr B = alloc("B", std::vector({alloc(10)}), kInt); - { - StmtPtr store = alloc(B, std::vector({I}), I); - // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) - StmtPtr block1 = alloc(std::vector({store})); - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - StmtPtr block2 = alloc(std::vector({store})); - // Stmt can't have multiple parents, thus inserting it into several blocks - // is illegal - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) - EXPECT_ANY_THROW(verify(block2)); - } -} - -TEST(IRVerifier, Store) { - VarPtr I = alloc("i", kInt); - VarPtr J = alloc("j", kLong); - VarPtr K = alloc("k", kFloat); - BufPtr B = alloc( - "b", - std::vector({alloc(10), alloc(20)}), - kFloat); - { - // Indices with different int dtypes (kInt, kLong) are ok - auto a = alloc(B, std::vector({I, J}), K); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) - EXPECT_NO_THROW(verify(a)); - } - { - // Float index - auto a = alloc(B, std::vector({K, K}), K); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) - EXPECT_ANY_THROW(verify(a)); - } - { - // Multilanes are only allowed in flattened indices - auto multilane_index = alloc(I, alloc(1), 4); - auto a = alloc(B, std::vector({I, multilane_index}), K); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) - EXPECT_ANY_THROW(verify(a)); - } - { - // Value and buf dtypes mismatch - auto a = alloc(B, std::vector({I}), I); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) - EXPECT_ANY_THROW(verify(a)); - } -} - -} // namespace jit -} // namespace torch diff --git a/test/cpp/tensorexpr/test_kernel.cpp b/test/cpp/tensorexpr/test_kernel.cpp deleted file mode 100644 index dc67928b111a..000000000000 --- a/test/cpp/tensorexpr/test_kernel.cpp +++ /dev/null @@ -1,2133 +0,0 @@ -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace torch { -namespace jit { - -using namespace torch::indexing; -using namespace torch::jit::tensorexpr; - -class Kernel : public ::testing::Test { - public: - void SetUp() override { - getTEMustUseLLVMOnCPU() = false; - } -}; - -TEST_F(Kernel, ParallelExternalCallBuf) { - const auto graph_string = R"IR( - graph(%0 : Float(1000, 5000, strides=[5000, 1], device=cpu), - %1 : Float(1000, 5000, strides=[5000, 1], device=cpu), - %2 : Float(5000, 1000, strides=[5000, 1], device=cpu)): - %3 : Float(1000, 5000, strides=[5000, 1], device=cpu) = aten::mul(%0, %1) - %4 : Float(1000, 5000, strides=[5000, 1], device=cpu) = aten::matmul(%3, %2) - return (%4))IR"; - auto graph = std::make_shared(); - torch::jit::parseIR(graph_string, graph.get()); -#ifdef TORCH_ENABLE_LLVM - const std::string& verification_pattern = - R"IR( -# CHECK: for (int64_t i = 0ll; i < 5000ll; i++) /* parallel */{)IR"; - - TensorExprKernel k(graph); - StmtPtr s = k.getCodeGenStmt(); - std::ostringstream oss; - oss << *s; - 
torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -#endif -} - -TEST_F(Kernel, InliningIntermediates) { - // here, each mul has only one use, so it should be completely inlined - { - const auto graph_string = R"IR( - graph(%0 : Float(5, 3, strides=[3, 1], device=cpu), - %1 : Float(5, 3, strides=[3, 1], device=cpu)): - %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1) - %one : int = prim::Constant[value=1]() - %4 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2) - %5: Float(5, 3, strides=[3, 1]) = aten::add(%4, %1, %one) - return (%5))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - TensorExprKernel k(graph); - auto stmt = k.getCodeGenStmt(); - std::ostringstream oss; - oss << *stmt; - torch::jit::testing::FileCheck().check_not("aten_mul")->run(oss.str()); - } - { - const auto graph_template = R"IR( - graph(%0 : Float(5, 3, strides=[3, 1], device=${device}), - %1 : Float(5, 3, strides=[3, 1], device=${device})): - %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1) - %one : int = prim::Constant[value=1]() - %3 : Float(5, 3, strides=[3, 1]) = aten::sub(%0, %2, %one) - %4 : Float(5, 3, strides=[3, 1]) = aten::add(%3, %0, %one) - %5 : Float(5, 3, strides=[3, 1]) = aten::div(%3, %0) - return (%4, %5))IR"; - for (bool use_cuda : {false, true}) { - if (!torch::cuda::is_available() && use_cuda) { - continue; - } - - at::jit::TemplateEnv env; - env.s("device", use_cuda ? "cuda:0" : "cpu"); - const auto graph_string = format(graph_template, env); - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - TensorExprKernel k(graph); - auto stmt = k.getCodeGenStmt(); - std::ostringstream oss; - oss << *stmt; - // aten_mul only has one use, inlined completely - torch::jit::testing::FileCheck().check_not("aten_mul")->run(oss.str()); - - // aten_sub should be removed by the CUDA backend by metavar rewriting - // and by the CPU backend by horizontal fusion. 
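-      // In other words, even though the sub result feeds two consumers (the
-      // add and the div), neither backend should materialize an intermediate
-      // aten_sub buffer, so the generated statement must not mention it.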
- torch::jit::testing::FileCheck().check_not("aten_sub")->run(oss.str()); - } - } -} - -TEST_F(Kernel, PreAllocIntermediateBufs) { - const auto graph_string = R"IR( -graph(%a.1 : Float(8, 8, strides=[8, 1], requires_grad=0, device=cpu), - %b.1 : Float(8, 8, strides=[8, 1], requires_grad=0, device=cpu)): - %2 : int = prim::Constant[value=1]() - %c.2 : Float(8, 8, strides=[8, 1], requires_grad=0, device=cpu) = aten::matmul(%a.1, %b.1) # test_matmul.py:12:12 - %3 : Float(8, 8, strides=[8, 1], requires_grad=0, device=cpu) = aten::add(%a.1, %c.2, %2) # test_matmul.py:13:15 - return (%3))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto a = at::rand({8, 8}, TensorOptions(kCPU).dtype(at::kFloat)); - auto b = at::rand({8, 8}, TensorOptions(kCPU).dtype(at::kFloat)); - auto o = at::zeros({8, 8}, TensorOptions(kCPU).dtype(at::kFloat)); - auto ref = at::matmul(a, b) + a; - TensorExprKernel k(graph, {}, {}, true); - - std::vector inputs = {a, b}; - auto stmt = k.getCodeGenStmt(); - - std::ostringstream oss; - oss << *stmt; - - // Check whether the intermediate buffer has been added to constants - auto constants = k.getConstantDescriptors(); - ASSERT_EQ(constants.size(), 1); - - // Check the IR we produced - torch::jit::testing::FileCheck().check_not("Alloc")->run(oss.str()); - torch::jit::testing::FileCheck().check_not("Free")->run(oss.str()); - - // Check correctness - std::vector stack = fmap(inputs); - k.run(stack); - o = stack[0].toTensor(); - ASSERT_TRUE(at::allclose(o, ref)); -} - -TEST_F(Kernel, _1) { - const auto graph_string = R"IR( - graph(%0 : Float(5, 3, strides=[3, 1], device=cpu), - %1 : Float(5, 3, strides=[3, 1], device=cpu)): - %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1) - %3 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2) - return (%3))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); - auto b = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); - auto o = at::zeros({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); - auto ref = a * (a * b); - TensorExprKernel k(graph); - std::vector inputs = {a, b}; - StmtPtr s = k.getCodeGenStmt(); - - std::ostringstream oss; - oss << *s; - - // Check the IR we produced - const std::string& verification_pattern = - R"IR( -# CHECK: for -# CHECK-NEXT: for -# CHECK-NOT: for)IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - std::vector stack = fmap(inputs); - k.run(stack); - o = stack[0].toTensor(); - for (size_t i = 0; i < 5 * 3; i++) { - TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); - } -} - -TEST_F(Kernel, _2) { - const auto graph_string = R"IR( - graph(%0 : Float(5, 3, strides=[3, 1], device=cpu), - %1 : Float(5, 3, strides=[1, 5], device=cpu)): - %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1) - %3 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2) - return (%3))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); - auto b = - at::rand({3, 5}, TensorOptions(kCPU).dtype(at::kFloat)).transpose(0, 1); - auto o = at::zeros({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); - auto ref = a * (a * b); - TensorExprKernel k(graph); - std::vector inputs = {a, b}; - StmtPtr s = k.getCodeGenStmt(); - - std::ostringstream oss; - oss << *s; - - // Check the IR we produced - const std::string& verification_pattern = - R"IR( -# CHECK: for -# CHECK-NEXT: for -# CHECK-NOT: 
for)IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - std::vector stack = fmap(inputs); - k.run(stack); - o = stack[0].toTensor(); - for (size_t i = 0; i < 5 * 3; i++) { - TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); - } -} - -TEST_F(Kernel, _3) { - const auto graph_string = R"IR( - graph(%0 : Float(5, 3, strides=[3, 1], device=cpu), - %1 : Float(5, 3, strides=[12, 2], device=cpu)): - %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1) - %3 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2) - return (%3))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); - auto b = at::rand({10, 6}, TensorOptions(kCPU).dtype(at::kFloat)) - .index({Slice(None, None, 2), Slice(None, None, 2)}); - auto o = at::zeros({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); - auto ref = a * (a * b); - TensorExprKernel k(graph); - std::vector inputs = {a, b}; - StmtPtr s = k.getCodeGenStmt(); - - std::ostringstream oss; - oss << *s; - - // Check the IR we produced - const std::string& verification_pattern = - R"IR( -# CHECK: for -# CHECK-NEXT: for -# CHECK-NOT: for)IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - std::vector stack = fmap(inputs); - k.run(stack); - o = stack[0].toTensor(); - for (size_t i = 0; i < 5 * 3; i++) { - TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); - } -} - -TEST_F(Kernel, Huge) { - const auto graph_string = R"IR( - graph(%x.1 : Float(4000000000, strides=[1], requires_grad=0, device=cpu)): - %1 : int = prim::Constant[value=0]() - %2 : Float(1, 4000000000, strides=[4000000000, 1], requires_grad=0, device=cpu) = aten::unsqueeze(%x.1, %1) - %3 : Float(1, 4000000000, strides=[4000000000, 1], requires_grad=0, device=cpu) = aten::relu(%2) - return (%3))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - TensorExprKernel k(graph); - std::ostringstream oss; - oss << *k.getCodeGenStmt(); - // The 4000000000 iterations loop will be split into 500000000 x 8 and the - // outer loop will be parallel. If LLVM is not present, it will not be split, - // and to cover both of these cases we're looking for 00000000ll; in the - // output. 
- const std::string& verification_pattern = R"IR(# CHECK: 00000000ll;)IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST_F(Kernel, ParallelStrided) { - const auto graph_string = R"IR( - graph(%0 : Float(5, 3, 40005, strides=[120015, 40005, 1], device=cpu), - %1 : Float(5, 3, 40005, strides=[960120, 160020, 2], device=cpu)): - %2 : Float(5, 3, 40005, strides=[120015, 40005, 1]) = aten::mul(%0, %1) - %3 : Float(5, 3, 40005, strides=[120015, 40005, 1]) = aten::mul(%0, %2) - return (%3))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto a = at::rand({5, 3, 40005}, TensorOptions(kCPU).dtype(at::kFloat)); - auto b = at::rand({10, 6, 80010}, TensorOptions(kCPU).dtype(at::kFloat)) - .index( - {Slice(None, None, 2), - Slice(None, None, 2), - Slice(None, None, 2)}); - auto ref = a * (a * b); - auto o = at::zeros_like(ref); - TensorExprKernel k(graph); - std::vector inputs = {a, b}; - std::vector stack = fmap(inputs); - k.run(stack); - o = stack[0].toTensor(); - for (size_t i = 0; i < 5 * 3; i++) { - TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); - } -} - -TEST_F(Kernel, DISABLED_Shape_Inference) { - // disabled: doesn't do stride propagation, and isn't being used currently - - // Test TensorExpr shape inference capabilities: it should only require shapes - // for the inputs - { - const auto graph_string = R"IR( - graph(%0 : Float(5, 3, strides=[3, 1], device=cpu), - %1 : Float(5, 3, strides=[12, 2], device=cpu)): - %2 : Tensor = aten::mul(%0, %1) - %3 : Tensor = aten::mul(%0, %2) - return (%3))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); - auto b = at::rand({10, 6}, TensorOptions(kCPU).dtype(at::kFloat)) - .index({Slice(None, None, 2), Slice(None, None, 2)}); - auto o = at::zeros({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); - auto ref = a * (a * b); - TensorExprKernel k(graph); - std::vector inputs = {a, b}; - StmtPtr s = k.getCodeGenStmt(); - - std::ostringstream oss; - oss << *s; - - // Check the IR we produced - const std::string& verification_pattern = - R"IR( -# CHECK: for -# CHECK-NEXT: for -# CHECK-NOT: for)IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - std::vector stack = fmap(inputs); - k.run(stack); - o = stack[0].toTensor(); - for (size_t i = 0; i < 5 * 3; i++) { - TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); - } - } - { - const auto graph_string = R"IR( - graph(%0 : Float(8, 8, strides=[8, 1], device=cpu), - %1 : Float(8, 8, strides=[8, 1], device=cpu)): - %2 : Tensor = aten::mul(%0, %1) - %3 : Tensor, %4 : Tensor = prim::ConstantChunk[dim=1,chunks=2](%2) - %r : Tensor = aten::mul(%3, %4) - return (%r))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto a = at::rand({8, 8}, TensorOptions(kCPU).dtype(at::kFloat)); - auto b = at::rand({8, 8}, TensorOptions(kCPU).dtype(at::kFloat)); - auto o = at::zeros({8, 4}, TensorOptions(kCPU).dtype(at::kFloat)); - auto t = torch::chunk(a * b, 2, 1); - auto ref = t[0] * t[1]; - TensorExprKernel k(graph); - std::vector inputs = {a, b}; - StmtPtr s = k.getCodeGenStmt(); - - std::ostringstream oss; - oss << *s; - - // Check the IR we produced - const std::string& verification_pattern = - R"IR( -# CHECK: for)IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - std::vector stack = fmap(inputs); - k.run(stack); - o = stack[0].toTensor(); - 
TORCH_CHECK_EQ(o.sizes()[0], 8); - TORCH_CHECK_EQ(o.sizes()[1], 4); - for (size_t i = 0; i < 8 * 4; i++) { - TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); - } - } - { - // Test that shape inference handles aten::unsqueeze - - const auto graph_string = R"IR( - graph(%a : Float(4, 2, strides=[2, 1], device=cpu), - %b : Float(4, 3, 2, strides=[6, 2, 1], device=cpu), - %c : Float(3, 2, 2, strides=[4, 2, 1], device=cpu)): - %one : int = prim::Constant[value=1]() - %minus_one : int = prim::Constant[value=-1]() - %three : int = prim::Constant[value=3]() - %minus_four : int = prim::Constant[value=-4]() - %a1 : Tensor = aten::unsqueeze(%a, %one) # new size: [4,1,2] - %a2 : Tensor = aten::unsqueeze(%a1, %minus_one) # new size: [4,1,2,1] - %b1 : Tensor = aten::unsqueeze(%b, %three) # new size: [4,3,2,1] - %c1 : Tensor = aten::unsqueeze(%c, %minus_four) # new size: [1,3,2,2] - %ab : Tensor = aten::mul(%a2, %b1) # expected size: [4,3,2,1] - %abc : Tensor = aten::mul(%ab, %c1) # expected size: [4,3,2,2] - return (%abc))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto a = at::rand({4, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - auto b = at::rand({4, 3, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - auto c = at::rand({3, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - auto o = at::zeros({4, 3, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - auto ref = at::unsqueeze(at::unsqueeze(a, 1), -1) * at::unsqueeze(b, 3) * - at::unsqueeze(c, -4); - - TensorExprKernel k(graph); - std::vector inputs = {a, b, c}; - StmtPtr s = k.getCodeGenStmt(); - - std::ostringstream oss; - oss << *s; - - // Check the IR we produced - const std::string& verification_pattern = - R"IR( -# CHECK: for -# CHECK-NEXT: for -# CHECK-NEXT: for -# CHECK-NEXT: for -# CHECK-NEXT: aten_mul)IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - std::vector stack = fmap(inputs); - k.run(stack); - o = stack[0].toTensor(); - - // Check sizes - TORCH_CHECK_EQ(o.sizes().size(), ref.sizes().size()); - size_t num_el = 1; - for (const auto idx : c10::irange(ref.sizes().size())) { - TORCH_CHECK_EQ(o.sizes()[idx], ref.sizes()[idx]); - num_el *= ref.sizes()[idx]; - } - - // Check the contents - for (const auto i : c10::irange(num_el)) { - TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); - } - } - { - // Test that shape inference handles aten::cat - - const auto graph_string = R"IR( - graph(%a : Float(5, 3, 2, strides=[6, 2, 1], device=cpu), - %b : Float(5, 7, 2, strides=[14, 2, 1], device=cpu), - %c : Float(5, 9, 2, strides=[18, 2, 1], device=cpu)): - %dim : int = prim::Constant[value=1]() - %inputs : Tensor[] = prim::ListConstruct(%a, %b, %c) - %r : Tensor = aten::cat(%inputs, %dim) # new size: [5,19,2] - return (%r))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto a = at::rand({5, 3, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - auto b = at::rand({5, 7, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - auto c = at::rand({5, 9, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - auto o = at::zeros({5, 19, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - auto ref = at::cat({a, b, c}, 1); - - TensorExprKernel k(graph); - std::vector inputs = {a, b, c}; - StmtPtr s = k.getCodeGenStmt(); - - std::ostringstream oss; - oss << *s; - - // Check the IR we produced - const std::string& verification_pattern = - R"IR( -# CHECK: for -# CHECK-NEXT: for -# CHECK-NEXT: for -# CHECK-NEXT: aten_cat)IR"; - 
torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - std::vector stack = fmap(inputs); - k.run(stack); - o = stack[0].toTensor(); - - // Check sizes - TORCH_CHECK_EQ(o.sizes().size(), ref.sizes().size()); - size_t num_el = 1; - for (const auto idx : c10::irange(ref.sizes().size())) { - TORCH_CHECK_EQ(o.sizes()[idx], ref.sizes()[idx]); - num_el *= ref.sizes()[idx]; - } - - // Check the contents - for (const auto i : c10::irange(num_el)) { - TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); - } - } - { - // Test that we throw an error when input list for aten::cat is empty - - const auto graph_string = R"IR( - graph(): - %dim : int = prim::Constant[value=1]() - %inputs : Tensor[] = prim::ListConstruct() - %r : Tensor = aten::cat(%inputs, %dim) - return (%r))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - auto compile = [&]() { - TensorExprKernel k(graph); - k.getCodeGenStmt(); - }; - ASSERT_THROWS_WITH(compile(), "Empty input list is passed to aten::cat"); - } - { - // Test that we throw an error when 'dim' passed to aten::cat is invalid - - const auto ir_dim_99 = R"IR( - graph(%a : Float(5, 3, 2, strides=[6, 2, 1], device=cpu), - %b : Float(5, 3, 2, strides=[6, 2, 1], device=cpu)): - %dim : int = prim::Constant[value=99]() - %inputs : Tensor[] = prim::ListConstruct(%a, %b) - %r : Float(5, 3, 2, strides=[6, 2, 1], device=cpu) = aten::cat(%inputs, %dim) - return (%r))IR"; - const auto ir_dim_minus_6 = R"IR( - graph(%a : Float(5, 3, 2, strides=[6, 2, 1], device=cpu), - %b : Float(5, 3, 2, strides=[6, 2, 1], device=cpu)): - %dim : int = prim::Constant[value=-6]() - %inputs : Tensor[] = prim::ListConstruct(%a, %b) - %r : Float(5, 3, 2, strides=[6, 2, 1], device=cpu) = aten::cat(%inputs, %dim) - return (%r))IR"; - - auto compile = [](const std::string& graph_string) { - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - TensorExprKernel k(graph); - k.getCodeGenStmt(); - }; - ASSERT_THROWS_WITH(compile(ir_dim_99), "Invalid index"); - ASSERT_THROWS_WITH(compile(ir_dim_minus_6), "Invalid index"); - } -} - -TEST_F(Kernel, CatInputTypesPromotion) { - { - // Test that we properly promote input types for aten::cat - - const auto graph_string = R"IR( - graph(%a : Float(5, 3, 2, strides=[6, 2, 1], device=cpu), - %b : Float(5, 7, 2, strides=[14, 2, 1], device=cpu), - %c : Double(5, 9, 2, strides=[18, 2, 1], device=cpu)): - %dim : int = prim::Constant[value=1]() - %inputs : Tensor[] = prim::ListConstruct(%a, %b, %c) - %r : Double(5, 19, 2, strides=[38, 2, 1]) = aten::cat(%inputs, %dim) - return (%r))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto a = at::rand({5, 3, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - auto b = at::rand({5, 7, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - auto c = at::rand({5, 9, 2}, TensorOptions(kCPU).dtype(at::kDouble)); - auto ref = at::cat({a, b, c}, 1); - - TensorExprKernel k(graph); - std::vector inputs = {a, b, c}; - StmtPtr s = k.getCodeGenStmt(); - - std::ostringstream oss; - oss << *s; - - // Check the IR we produced - const std::string& verification_pattern = - R"IR( -# CHECK: for -# CHECK-NEXT: for -# CHECK-NEXT: for -# CHECK-NEXT: aten_cat)IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - std::vector stack = fmap(inputs); - k.run(stack); - auto o = stack[0].toTensor(); - - // Check sizes - TORCH_CHECK_EQ(o.sizes().size(), ref.sizes().size()); - TORCH_CHECK_EQ(o.dtype(), ref.dtype()); - size_t num_el = 1; - for 
(const auto idx : c10::irange(ref.sizes().size())) { - TORCH_CHECK_EQ(o.sizes()[idx], ref.sizes()[idx]); - num_el *= ref.sizes()[idx]; - } - - // Check the contents - for (const auto i : c10::irange(num_el)) { - TORCH_CHECK_EQ(((double*)o.data_ptr())[i], ((double*)ref.data_ptr())[i]); - } - } -} - -TEST_F(Kernel, ToDType) { -#ifdef TORCH_ENABLE_LLVM - const auto graph_string = R"IR( - graph(%x.1 : BFloat16(2, 2, strides=[2, 1], requires_grad=0, device=cpu)): - %1 : NoneType = prim::Constant() - %2 : bool = prim::Constant[value=0]() - %3 : int = prim::Constant[value=6]() - %4 : int = prim::Constant[value=15]() - %5 : int = prim::Constant[value=5]() - %6 : bool = prim::Constant[value=1]() - %y.3 : BFloat16(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::sigmoid(%x.1) - %z.3 : BFloat16(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::_autocast_to_reduced_precision(%y.3, %6, %6, %5, %4) - %h.3 : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::_autocast_to_full_precision(%z.3, %6, %6) - %i.3 : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::to(%h.3, %3, %2, %2, %1) - %j.3 : BFloat16(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::to(%i.3, %4, %2, %2, %1) - %k.3 : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::to(%j.3, %3, %2, %2, %1) - return (%k.3))IR"; - - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - TensorExprKernel k(graph); - StmtPtr s = k.getCodeGenStmt(); - std::ostringstream oss; - oss << *s; - - const std::string& verification_pattern = - R"IR( -# CHECK: for -# CHECK-NEXT: for -# CHECK-NEXT: aten_to -# CHECK-NEXT: } -# CHECK-NEXT: })IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - auto a = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kBFloat16)); - auto ref = - at::_to_copy(at::sigmoid(a), TensorOptions(kCPU).dtype(at::kFloat)); - - std::vector inputs = {a}; - std::vector stack = fmap(inputs); - k.run(stack); - auto o = stack[0].toTensor(); - ASSERT_EQ(o.sizes(), ref.sizes()); - ASSERT_EQ(o.dtype(), ref.dtype()); - ASSERT_TRUE(at::allclose(o, ref, 4E-3, 4E-3)); -#endif -} - -TEST_F(Kernel, CatAndInlineWithAConstantDim) { - const auto graph_string = R"IR( - graph(%0 : Float(1, 512, strides=[1024, 1], requires_grad=0, device=cpu), - %1 : Float(1, 512, strides=[1024, 1], requires_grad=0, device=cpu)): - %2 : bool = prim::Constant[value=0]() - %3 : int = prim::Constant[value=1]() - %4 : Tensor[] = prim::ListConstruct(%0, %1) - %5 : Float(1, 1024, strides=[1024, 1], requires_grad=0, device=cpu) = aten::cat(%4, %3) - %6 : Tensor[] = prim::ListConstruct(%5) - %7 : Float(1, 1024, strides=[1024, 1], requires_grad=0, device=cpu) = aten::cat(%6, %3) - %8 : Float(1, 1024, strides=[1024, 1], requires_grad=0, device=cpu) = aten::_cast_Float(%7, %2) - return (%8, %7))IR"; - - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - TensorExprKernel k(graph); - - auto a = at::rand({1, 512}, TensorOptions(kCPU).dtype(at::kFloat)); - auto b = at::rand({1, 512}, TensorOptions(kCPU).dtype(at::kFloat)); - auto ref = at::_cast_Float(at::cat({a, b}, 1), 0); - - std::vector inputs = {a, b}; - std::vector stack = fmap(inputs); - k.run(stack); - auto o = stack[0].toTensor(); - ASSERT_EQ(o.sizes(), ref.sizes()); - ASSERT_EQ(o.dtype(), ref.dtype()); - ASSERT_TRUE(at::allclose(o, ref)); -} - -TEST_F(Kernel, CatWithEmptyInputs) { - bool curr_cat_wo_conditionals = getCatWoConditionals(); - for (auto cat_wo_conditionals : {true, false}) { - getCatWoConditionals() = 
cat_wo_conditionals; - const auto graph_string = R"IR( - graph(%0 : Float(0, 64, strides=[64, 1], requires_grad=0, device=cpu), - %1 : Float(10, 64, strides=[64, 1], requires_grad=0, device=cpu)): - %3 : int = prim::Constant[value=0]() - %6 : Float(0, 64, strides=[64, 1], requires_grad=0, device=cpu) = aten::tanh(%0) - %7 : Float(10, 64, strides=[64, 1], requires_grad=0, device=cpu) = aten::tanh(%1) - %10 : Tensor[] = prim::ListConstruct(%6, %7) - %11 : Float(10, 64, strides=[64, 1], requires_grad=0, device=cpu) = aten::cat(%10, %3) - return (%11))IR"; - - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - TensorExprKernel k(graph); - - auto a = at::rand({0, 64}, TensorOptions(kCPU).dtype(at::kFloat)); - auto b = at::rand({10, 64}, TensorOptions(kCPU).dtype(at::kFloat)); - auto ref = at::cat({at::tanh(a), at::tanh(b)}, 0); - - std::vector inputs = {a, b}; - std::vector stack = fmap(inputs); - k.run(stack); - auto o = stack[0].toTensor(); - ASSERT_EQ(o.sizes(), ref.sizes()); - ASSERT_EQ(o.dtype(), ref.dtype()); - ASSERT_TRUE(at::allclose(o, ref)); - } - getCatWoConditionals() = curr_cat_wo_conditionals; -} - -TEST_F(Kernel, CatWoConditionals) { - bool old_cat_wo_conditionals = getCatWoConditionals(); - getCatWoConditionals() = true; - const auto graph_string = R"IR( - graph(%a : Float(5, 3, 2, strides=[6, 2, 1], device=cpu), - %b : Float(5, 7, 2, strides=[14, 2, 1], device=cpu), - %c : Float(5, 9, 2, strides=[18, 2, 1], device=cpu)): - %dim : int = prim::Constant[value=1]() - %inputs : Tensor[] = prim::ListConstruct(%a, %b, %c) - %r : Float(5, 19, 2, strides=[38, 2, 1]) = aten::cat(%inputs, %dim) - return (%r))IR"; - - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - TensorExprKernel k(graph); - StmtPtr s = k.getCodeGenStmt(); - std::ostringstream oss; - oss << *s; - - const std::string& verification_pattern = - R"IR( -# CHECK: for -# CHECK: for -# CHECK: for -# CHECK: aten_cat -# CHECK: for -# CHECK: for -# CHECK: aten_cat -# CHECK: for -# CHECK: for -# CHECK: aten_cat)IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - auto a = at::rand({5, 3, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - auto b = at::rand({5, 7, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - auto c = at::rand({5, 9, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - auto ref = at::cat({a, b, c}, 1); - - std::vector inputs = {a, b, c}; - std::vector stack = fmap(inputs); - k.run(stack); - auto o = stack[0].toTensor(); - - // Check sizes - TORCH_CHECK_EQ(o.sizes().size(), ref.sizes().size()); - TORCH_CHECK_EQ(o.dtype(), ref.dtype()); - size_t num_el = 1; - for (const auto idx : c10::irange(ref.sizes().size())) { - TORCH_CHECK_EQ(o.sizes()[idx], ref.sizes()[idx]); - num_el *= ref.sizes()[idx]; - } - - // Check the contents - for (const auto i : c10::irange(num_el)) { - TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); - } - getCatWoConditionals() = old_cat_wo_conditionals; -} - -TEST_F(Kernel, OptimizeConditionals) { - bool old_cat_wo_conditionals = getCatWoConditionals(); - bool old_opt_conditionals = getOptConditionals(); - getCatWoConditionals() = false; - getOptConditionals() = true; - const auto graph_string = R"IR( - graph(%a : Float(5, 3, strides=[3, 1], device=cpu), - %b : Float(5, 7, strides=[7, 1], device=cpu), - %c : Float(5, 9, strides=[9, 1], device=cpu)): - %dim : int = prim::Constant[value=1]() - %inputs : Tensor[] = prim::ListConstruct(%a, %b, %c) - %r : Float(5, 19, strides=[19, 1]) = aten::cat(%inputs, %dim) - %t : 
Float(5, 19, strides=[19, 1]) = aten::relu(%r)
-        return (%t))IR";
-
-  auto graph = std::make_shared<Graph>();
-  parseIR(graph_string, &*graph);
-
-  TensorExprKernel k(graph);
-  StmtPtr s = k.getCodeGenStmt();
-  std::ostringstream oss;
-  oss << *s;
-
-  const std::string& verification_pattern =
-      R"IR(
-# CHECK: for
-# CHECK-NEXT: for
-# CHECK-NEXT: aten_relu
-# CHECK: for
-# CHECK-NEXT: aten_relu
-# CHECK: for
-# CHECK-NEXT: aten_relu
-# CHECK-NOT: Allocate
-# CHECK-NOT: Free)IR";
-  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
-
-  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
-  auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
-  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
-  auto b = at::rand({5, 7}, TensorOptions(kCPU).dtype(at::kFloat));
-  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
-  auto c = at::rand({5, 9}, TensorOptions(kCPU).dtype(at::kFloat));
-  auto ref = at::relu(at::cat({a, b, c}, 1));
-
-  std::vector<at::Tensor> inputs = {a, b, c};
-  std::vector<IValue> stack = fmap<IValue>(inputs);
-  k.run(stack);
-  auto o = stack[0].toTensor();
-
-  // Check sizes
-  TORCH_CHECK_EQ(o.sizes().size(), ref.sizes().size());
-  TORCH_CHECK_EQ(o.dtype(), ref.dtype());
-  size_t num_el = 1;
-  for (const auto idx : c10::irange(ref.sizes().size())) {
-    TORCH_CHECK_EQ(o.sizes()[idx], ref.sizes()[idx]);
-    num_el *= ref.sizes()[idx];
-  }
-
-  // Check the contents
-  for (const auto i : c10::irange(num_el)) {
-    TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]);
-  }
-  getOptConditionals() = old_opt_conditionals;
-  getCatWoConditionals() = old_cat_wo_conditionals;
-}
-
-namespace {
-
-std::string dtypeConstant(ScalarType scalar_type) {
-  if (scalar_type == ScalarType::Undefined) {
-    return "None = prim::Constant()";
-  } else {
-    at::jit::TemplateEnv env_dtype;
-    env_dtype.d("scalar_type", static_cast<int>(scalar_type));
-    return format("int = prim::Constant[value=${scalar_type}]()", env_dtype);
-  }
-}
-
-at::Tensor iotaTensor(IntArrayRef sizes, const at::TensorOptions& options) {
-  int64_t numel = std::accumulate(
-      sizes.begin(),
-      sizes.end(),
-      1,
-      // NOLINTNEXTLINE(modernize-use-transparent-functors)
-      std::multiplies<int64_t>());
-  std::vector<float> values(numel);
-  std::iota(values.begin(), values.end(), 0);
-  auto a = at::tensor(values, options);
-  return a.reshape(sizes);
-}
-
-} // namespace
-
-TEST_F(Kernel, SumAllAxes) {
-  // Test lowering of sum on all axes.
-  const auto graph_template = R"IR(
-      graph(%0 : Float(5, 3, strides=[3, 1], device=cpu)):
-        %1 : ${dtype}
-        %2 : ${out_dtype}(requires_grad=0, device=cpu) = aten::sum(%0, %1)
-        return (%2))IR";
-  auto a = iotaTensor({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
-
-  for (auto scalar_type : {ScalarType::Undefined, ScalarType::Double}) {
-    at::jit::TemplateEnv env;
-    env.s("dtype", dtypeConstant(scalar_type));
-    if (scalar_type == ScalarType::Undefined) {
-      env.s("out_dtype", "Float");
-    } else {
-      env.s("out_dtype", "Double");
-    }
-    const auto graph_string = format(graph_template, env);
-
-    auto graph = std::make_shared<Graph>();
-    parseIR(graph_string, &*graph);
-
-    auto o = at::empty({}, TensorOptions(kCPU));
-    std::optional<c10::ScalarType> dtype;
-    if (scalar_type != ScalarType::Undefined) {
-      dtype = static_cast<c10::ScalarType>(scalar_type);
-    }
-    auto ref = a.sum(/*dtype=*/dtype);
-    TensorExprKernel k(graph);
-    std::vector<at::Tensor> inputs = {a};
-    StmtPtr s = k.getCodeGenStmt();
-
-    std::ostringstream oss;
-    oss << *s;
-
-    // Check the IR we produced
-    const std::string& verification_pattern =
-        R"IR(
-# CHECK: for
-# CHECK-NEXT: for)IR";
-    torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
-
-    std::vector<IValue> stack = fmap<IValue>(inputs);
-    k.run(stack);
-    o = stack[0].toTensor();
-    ASSERT_EQ(o.sizes(), ref.sizes());
-    ASSERT_EQ(o.dtype(), ref.dtype());
-    ASSERT_TRUE(at::allclose(o, ref));
-  }
-}
-
-std::string li_to_str(at::ArrayRef<int64_t> li) {
-  std::stringstream out;
-  bool first = true;
-  for (auto elem : li) {
-    if (!first) {
-      out << ", ";
-    }
-    out << elem;
-    first = false;
-  }
-  return out.str();
-}
-
-TEST_F(Kernel, SumOneAxis) {
-  // Test lowering of sum on one axis.
-  const auto graph_template = R"IR(
-      graph(%0 : Float(5, 3, strides=[3, 1], device=cpu)):
-        %1 : int[] = prim::Constant[value=[${dim}]]()
-        %2 : bool = prim::Constant[value=${keepdim}]()
-        %3 : ${dtype}
-        %4 : ${out_dtype}(${size}, strides=[${strides}], device=cpu) = aten::sum(%0, %1, %2, %3)
-        return (%4))IR";
-  auto a = iotaTensor({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
-
-  for (int dim = -a.dim(); dim < a.dim(); ++dim) {
-    for (bool keepdim : {false, true}) {
-      for (auto scalar_type : {ScalarType::Undefined, ScalarType::Double}) {
-        at::jit::TemplateEnv env;
-        env.d("dim", dim);
-        env.d("keepdim", keepdim);
-        env.s("dtype", dtypeConstant(scalar_type));
-        std::optional<c10::ScalarType> dtype;
-        if (scalar_type != ScalarType::Undefined) {
-          dtype = static_cast<c10::ScalarType>(scalar_type);
-        }
-        auto ref = a.sum({dim}, /*keepdim=*/keepdim, /*dtype=*/dtype);
-        if (scalar_type == ScalarType::Undefined) {
-          env.s("out_dtype", "Float");
-        } else {
-          env.s("out_dtype", "Double");
-        }
-        env.s("size", li_to_str(ref.sizes()));
-        env.s("strides", li_to_str(ref.strides()));
-        const auto graph_string = format(graph_template, env);
-        auto graph = std::make_shared<Graph>();
-        parseIR(graph_string, &*graph);
-
-        auto o = at::empty({}, TensorOptions(kCPU));
-        TensorExprKernel k(graph);
-        std::vector<at::Tensor> inputs = {a};
-        StmtPtr s = k.getCodeGenStmt();
-
-        std::ostringstream oss;
-        oss << *s;
-
-        // Check the IR we produced
-        const std::string& verification_pattern =
-            R"IR(
-# CHECK: for (int64_t
-# CHECK-NEXT: sum
-# CHECK-NEXT: for (int64_t
-# CHECK-NEXT: sum)IR";
-        torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
-
-        std::vector<IValue> stack = fmap<IValue>(inputs);
-        k.run(stack);
-        o = stack[0].toTensor();
-        ASSERT_EQ(o.sizes(), ref.sizes());
-        ASSERT_EQ(o.dtype(), ref.dtype());
-        ASSERT_TRUE(at::allclose(o, ref, 4E-3, 4E-3));
-      }
-    }
-  }
-}
-
-TEST_F(Kernel, SumMultipleAxes) {
-  // Test lowering of sum on multiple axes.
-  const auto graph_template = R"IR(
-      graph(%0 : Float(2, 3, 2, 3, strides=[18, 6, 3, 1], requires_grad=0, device=cpu)):
-        %1 : int = prim::Constant[value=${dim1}]()
-        %2 : int = prim::Constant[value=${dim2}]()
-        %3 : int[] = prim::ListConstruct(%1, %2)
-        %4 : bool = prim::Constant[value=${keepdim}]()
-        %5 : ${dtype}
-        %6 : Float(${size}, strides=[${strides}], requires_grad=0, device=cpu) = aten::sum(%0, %3, %4, %5)
-        return (%6))IR";
-  auto a = iotaTensor({2, 3, 2, 3}, TensorOptions(kCPU).dtype(at::kFloat));
-
-  // Only iterate over positive values of axes to keep the running time
-  // reasonable, since the number of pairs is quadratic.
-  for (const auto dim1 : c10::irange(a.dim())) {
-    for (int dim2 = dim1 + 1; dim2 < a.dim(); ++dim2) {
-      for (bool keepdim : {false, true}) {
-        at::jit::TemplateEnv env;
-        env.d("dim1", dim1);
-        env.d("dim2", dim2);
-        env.d("keepdim", keepdim);
-        env.s("dtype", dtypeConstant(ScalarType::Undefined));
-        auto o = at::empty({}, TensorOptions(kCPU));
-        auto ref = a.sum(IntArrayRef{dim1, dim2}, /*keepdim=*/keepdim);
-
-        env.s("size", li_to_str(ref.sizes()));
-        env.s("strides", li_to_str(ref.strides()));
-
-        const auto graph_string = format(graph_template, env);
-
-        auto graph = std::make_shared<Graph>();
-        parseIR(graph_string, &*graph);
-
-        TensorExprKernel k(graph);
-        std::vector<at::Tensor> inputs = {a};
-        StmtPtr s = k.getCodeGenStmt();
-
-        std::ostringstream oss;
-        oss << *s;
-
-        // Check the IR we produced
-        const std::string& verification_pattern =
-            R"IR(
-# CHECK: for (int64_t
-# CHECK: for (int64_t
-# CHECK: for (int64_t
-# CHECK: for (int64_t
-# CHECK: sum)IR";
-        torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
-
-        std::vector<IValue> stack = fmap<IValue>(inputs);
-        k.run(stack);
-        o = stack[0].toTensor();
-        ASSERT_EQ(o.sizes(), ref.sizes());
-        ASSERT_EQ(o.dtype(), ref.dtype());
-        ASSERT_TRUE(at::allclose(o, ref));
-      }
-    }
-  }
-}
-
-// This test and the following ones testing Softmax only test with dim set
-// to one of the valid input dimensions. They do not test with dim=None
-// because that is supposed to be deprecated.
-TEST_F(Kernel, Softmax2D) {
-  const auto graph_template = R"IR(
-      graph(%0 : Float(5, 3, strides=[3, 1], device=cpu)):
-        %1 : int = prim::Constant[value=${dim}]()
-        %dt_float : int = prim::Constant[value=7]()
-        %dt_none : NoneType = prim::Constant()
-        %4 : Float(${size}, strides=[${strides}]) = aten::${op}(%0, %1, %${dt})
-        return (%4))IR";
-
-  auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
-
-  const std::string& verification_template =
-      R"IR(
-        # CHECK: for (int i${other_dim} = 0; i${other_dim} < ${other_dim_size}
-        # CHECK: for (int i${softmax_dim} = 0; i${softmax_dim} < ${softmax_dim_size}
-        # CHECK-NEXT: aten_softmax_max
-        # CHECK: for (int i${other_dim}_1 = 0; i${other_dim}_1 < ${other_dim_size}
-        # CHECK: for (int i${softmax_dim}_1 = 0; i${softmax_dim}_1 < ${softmax_dim_size}
-        # CHECK-NEXT: aten_softmax_sum
-        # CHECK: for (int i0_2 = 0; i0_2 < 5
-        # CHECK-NEXT: for (int i1_2 = 0; i1_2 < 3
-        # CHECK-NEXT: aten_softmax)IR";
-
-  for (bool empty_dtype : {false, true}) {
-    for (auto log_softmax : {false, true}) {
-      for (const auto softmax_dim : c10::irange(a.dim())) {
-        auto softmax_dim_size = a.sizes()[softmax_dim];
-        auto other_dim = (softmax_dim + 1) % a.dim();
-        auto ref =
-            log_softmax ? a.log_softmax(softmax_dim) : a.softmax(softmax_dim);
-        at::jit::TemplateEnv env;
-        env.d("dim", softmax_dim);
-        env.s("op", log_softmax ?
"log_softmax" : "softmax"); - env.s("size", li_to_str(ref.sizes())); - env.s("strides", li_to_str(ref.strides())); - env.s("dt", empty_dtype ? "dt_none" : "dt_float"); - - const auto graph_string = format(graph_template, env); - - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - TensorExprKernel k(graph); - std::vector inputs = {a}; - StmtPtr s = k.getCodeGenStmt(); - - std::ostringstream oss; - oss << *s; - - at::jit::TemplateEnv ver_env; - ver_env.d("other_dim", other_dim); - ver_env.d("other_dim_size", a.sizes()[other_dim]); - ver_env.d("softmax_dim", softmax_dim); - ver_env.d("softmax_dim_size", softmax_dim_size); - const auto verification_pattern = - format(verification_template, ver_env); - - // verification string temporarily disabled until - // inlining of exp() is benchmarked and determined - // torch::jit::testing::FileCheck().run(verification_pattern, - // oss.str()); - - std::vector stack = fmap(inputs); - k.run(stack); - auto output = stack[0].toTensor(); - ASSERT_EQ(output.sizes(), ref.sizes()); - ASSERT_TRUE(at::allclose(output, ref)); - } - } - } -} - -TEST_F(Kernel, Softmax3D) { - const auto graph_template = R"IR( - graph(%0 : Float(3, 4, 5, strides=[20, 5, 1], device=cpu)): - %1 : int = prim::Constant[value=${dim}]() - %2 : int = prim::Constant[value=7]() - %3 : Float(${size}, strides=[${strides}]) = aten::${op}(%0, %1, %2) - return (%3))IR"; - - auto a = at::rand({3, 4, 5}, TensorOptions(kCPU).dtype(at::kFloat)); - - const std::string& verification_template = - R"IR( - # CHECK: for (int i${dim1} = 0; i${dim1} < ${dim1_size} - # CHECK-NEXT: for (int i${dim2} = 0; i${dim2} < ${dim2_size} - # CHECK: for (int i${softmax_dim} = 0; i${softmax_dim} < ${softmax_dim_size} - # CHECK-NEXT: aten_softmax_max - # CHECK: for (int i${dim1}_1 = 0; i${dim1}_1 < ${dim1_size} - # CHECK-NEXT: for (int i${dim2}_1 = 0; i${dim2}_1 < ${dim2_size} - # CHECK: for (int i${softmax_dim}_1 = 0; i${softmax_dim}_1 < ${softmax_dim_size} - # CHECK-NEXT: aten_softmax_sum - # CHECK: for (int i0_2 = 0; i0_2 < 3 - # CHECK-NEXT: for (int i1_2 = 0; i1_2 < 4 - # CHECK-NEXT: for (int i2_2 = 0; i2_2 < 5 - # CHECK-NEXT: aten_softmax)IR"; - - for (auto log_softmax : {false, true}) { - for (const auto softmax_dim : c10::irange(a.dim())) { - auto softmax_dim_size = a.sizes()[softmax_dim]; - std::vector other_dims; - for (const auto i : c10::irange(a.dim())) { - if (i != softmax_dim) { - other_dims.push_back(i); - } - } - auto ref = - log_softmax ? a.log_softmax(softmax_dim) : a.softmax(softmax_dim); - - at::jit::TemplateEnv env; - env.d("dim", softmax_dim); - env.s("op", log_softmax ? 
"log_softmax" : "softmax"); - env.s("size", li_to_str(ref.sizes())); - env.s("strides", li_to_str(ref.strides())); - - const auto graph_string = format(graph_template, env); - - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - TensorExprKernel k(graph); - std::vector inputs = {a}; - StmtPtr s = k.getCodeGenStmt(); - - std::ostringstream oss; - oss << *s; - - at::jit::TemplateEnv ver_env; - ver_env.d("dim1", other_dims[0]); - ver_env.d("dim1_size", a.sizes()[other_dims[0]]); - ver_env.d("dim2", other_dims[1]); - ver_env.d("dim2_size", a.sizes()[other_dims[1]]); - ver_env.d("softmax_dim", softmax_dim); - ver_env.d("softmax_dim_size", softmax_dim_size); - const auto verification_pattern = format(verification_template, ver_env); - - // verification string temporarily disabled until - // inlining of exp() is benchmarked and determined - // torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - std::vector stack = fmap(inputs); - k.run(stack); - auto output = stack[0].toTensor(); - - ASSERT_EQ(output.sizes(), ref.sizes()); - ASSERT_TRUE(at::allclose(output, ref)); - } - } -} - -TEST_F(Kernel, Softmax4D) { - const auto graph_template = R"IR( - graph(%0 : Float(2, 3, 2, 3, strides=[18, 6, 3, 1], device=cpu)): - %1 : int = prim::Constant[value=${dim}]() - %2 : int = prim::Constant[value=7]() - %3 : Float(${size}, strides=[${strides}]) = aten::${op}(%0, %1, %2) - return (%3))IR"; - - auto a = at::rand({2, 3, 2, 3}, TensorOptions(kCPU).dtype(at::kFloat)); - - const std::string& verification_template = - R"IR( - # CHECK: for (int i${dim1} = 0; i${dim1} < ${dim1_size} - # CHECK-NEXT: for (int i${dim2} = 0; i${dim2} < ${dim2_size} - # CHECK-NEXT: for (int i${dim3} = 0; i${dim3} < ${dim3_size} - # CHECK: for (int i${softmax_dim} = 0; i${softmax_dim} < ${softmax_dim_size} - # CHECK-NEXT: aten_softmax_max - # CHECK: for (int i${dim1}_1 = 0; i${dim1}_1 < ${dim1_size} - # CHECK-NEXT: for (int i${dim2}_1 = 0; i${dim2}_1 < ${dim2_size} - # CHECK-NEXT: for (int i${dim3}_1 = 0; i${dim3}_1 < ${dim3_size} - # CHECK: for (int i${softmax_dim}_1 = 0; i${softmax_dim}_1 < ${softmax_dim_size} - # CHECK-NEXT: aten_softmax_sum - # CHECK: for (int i0_2 = 0; i0_2 < 2 - # CHECK-NEXT: for (int i1_2 = 0; i1_2 < 3 - # CHECK-NEXT: for (int i2_2 = 0; i2_2 < 2 - # CHECK-NEXT: for (int i3_2 = 0; i3_2 < 3 - # CHECK-NEXT: aten_softmax)IR"; - - for (auto log_softmax : {false, true}) { - for (const auto softmax_dim : c10::irange(a.dim())) { - auto softmax_dim_size = a.sizes()[softmax_dim]; - std::vector other_dims; - for (const auto i : c10::irange(a.dim())) { - if (i != softmax_dim) { - other_dims.push_back(i); - } - } - auto ref = - log_softmax ? a.log_softmax(softmax_dim) : a.softmax(softmax_dim); - - at::jit::TemplateEnv env; - env.d("dim", softmax_dim); - env.s("op", log_softmax ? 
"log_softmax" : "softmax"); - env.s("size", li_to_str(ref.sizes())); - env.s("strides", li_to_str(ref.strides())); - - const auto graph_string = format(graph_template, env); - - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - TensorExprKernel k(graph); - std::vector inputs = {a}; - StmtPtr s = k.getCodeGenStmt(); - - std::ostringstream oss; - oss << *s; - - at::jit::TemplateEnv ver_env; - ver_env.d("dim1", other_dims[0]); - ver_env.d("dim1_size", a.sizes()[other_dims[0]]); - ver_env.d("dim2", other_dims[1]); - ver_env.d("dim2_size", a.sizes()[other_dims[1]]); - ver_env.d("dim3", other_dims[2]); - ver_env.d("dim3_size", a.sizes()[other_dims[2]]); - ver_env.d("softmax_dim", softmax_dim); - ver_env.d("softmax_dim_size", softmax_dim_size); - const auto verification_pattern = format(verification_template, ver_env); - - // verification string temporarily disabled until - // inlining of exp() is benchmarked and determined - // torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - std::vector stack = fmap(inputs); - k.run(stack); - auto output = stack[0].toTensor(); - ASSERT_EQ(output.sizes(), ref.sizes()); - ASSERT_TRUE(at::allclose(output, ref)); - } - } -} - -TEST_F(Kernel, SignTest) { - const auto graph_template = R"IR( - graph(%0 : ${dtype}(${size}, strides=[1], device=cpu)): - %2 : ${dtype}(${size}, strides=[1]) = aten::sign(%0) - return (%2))IR"; - - auto run_test = [](const std::string& graph_string, const at::Tensor& input) { - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - TensorExprKernel k(graph); - StmtPtr s = k.getCodeGenStmt(); - - std::vector inputs = {input}; - std::vector stack = fmap(inputs); - k.run(stack); - auto o = stack[0].toTensor(); - auto ref = at::sign(input); - ASSERT_TRUE(at::allclose(o, ref)); - }; - auto common_options = at::TensorOptions() - .layout(at::kStrided) - .device(at::kCPU) - .requires_grad(false); - int default_input_size = 100; - for (auto scalar_type : {ScalarType::Float, ScalarType::Double}) { - at::Tensor corner_case_inputs; - at::jit::TemplateEnv env; - auto options = common_options; - switch (scalar_type) { - case ScalarType::Float: { - env.s("dtype", "Float"); - options = options.dtype(at::kFloat); - std::vector input_float = { - 0.0f, - -0.0f, - std::numeric_limits::infinity(), - -std::numeric_limits::infinity(), - std::nanf("1"), - -std::nanf("1")}; - corner_case_inputs = at::from_blob( - input_float.data(), - {static_cast(input_float.size())}, - options); - auto rand_input = at::rand({default_input_size}, options); - auto input = at::cat({rand_input, corner_case_inputs}); - env.d("size", at::numel(input)); - const auto graph_string = format(graph_template, env); - run_test(graph_string, input); - break; - } - case ScalarType::Double: { - env.s("dtype", "Double"); - options = options.dtype(at::kDouble); - std::vector input_double = { - 0.0, - -0.0, - std::numeric_limits::infinity(), - -std::numeric_limits::infinity(), - std::nan("1"), - -std::nan("1")}; - corner_case_inputs = at::from_blob( - input_double.data(), - {static_cast(input_double.size())}, - options); - auto rand_input = at::rand({default_input_size}, options); - auto input = at::cat({rand_input, corner_case_inputs}); - env.d("size", at::numel(input)); - const auto graph_string = format(graph_template, env); - run_test(graph_string, input); - break; - } - default: - throw unsupported_dtype(); - } - } -} - -TEST_F(Kernel, InlineProducerIntoReduction) { - // Inline producer (mul) into reduction (sum). 
-  const auto graph_string = R"IR(
-      graph(%0 : Float(5, 3, strides=[3, 1], device=cpu),
-            %1 : Float(5, 3, strides=[3, 1], device=cpu)):
-        %2 : Float(5, 3, strides=[3, 1], device=cpu) = aten::mul(%0, %1)
-        %3 : int = prim::Constant[value=7]()
-        %4 : Double(device=cpu) = aten::sum(%2, %3)
-        return (%4))IR";
-  auto graph = std::make_shared<Graph>();
-  parseIR(graph_string, &*graph);
-
-  TensorExprKernel k(graph);
-  StmtPtr s = k.getCodeGenStmt();
-  std::ostringstream oss;
-  oss << *s;
-
-  // Check the IR we produced.
-  // We should have only one loop in the end.
-  const std::string& verification_pattern =
-      R"IR(
-        # CHECK: for (int64_t i_1 = 0ll; i_1 < 5
-        # CHECK-NEXT: for (int64_t j_1 = 0ll; j_1 < 3
-        # CHECK-NEXT: sum
-        # CHECK-NOT: for)IR";
-  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
-
-  auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
-  auto b = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
-  std::vector<at::Tensor> inputs = {a, b};
-  std::vector<IValue> stack = fmap<IValue>(inputs);
-  k.run(stack);
-  auto o = stack[0].toTensor();
-  auto ref = (a * b).sum(at::kDouble);
-  ASSERT_TRUE(at::allclose(o, ref));
-}
-
-TEST_F(Kernel, InlineReductionIntoConsumer) {
-  // Inline producer (mul %2) into reduction (sum %4) but DO NOT
-  // inline the reduction into its consumer (mul %5).
-  const auto graph_string = R"IR(
-      graph(%0 : Float(5, 3, strides=[3, 1], device=cpu),
-            %1 : Float(5, 3, strides=[3, 1], device=cpu)):
-        %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1)
-        %3 : int = prim::Constant[value=6]()
-        %4 : Float(device=cpu) = aten::sum(%2, %3)
-        %5 : Float(5, 3, strides=[3, 1], device=cpu) = aten::mul(%2, %4)
-        return (%5))IR";
-  auto graph = std::make_shared<Graph>();
-  parseIR(graph_string, &*graph);
-
-  TensorExprKernel k(graph);
-  StmtPtr s = k.getCodeGenStmt();
-  std::ostringstream oss;
-  oss << *s;
-
-  // Check the IR we produced.
-  // We should have two loops in the end.
-  const std::string& verification_pattern =
-      R"IR(
-        # CHECK: for (int64_t i_1 = 0ll; i_1 < 5
-        # CHECK-NEXT: for (int64_t j_1 = 0ll; j_1 < 3
-        # CHECK-NEXT: sum
-        # CHECK: for (int64_t i_2 = 0ll; i_2 < 5
-        # CHECK-NEXT: for (int64_t j_2 = 0ll; j_2 < 3
-        # CHECK-NEXT: aten_mul
-        # CHECK-NOT: for)IR";
-  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
-
-  auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
-  auto b = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
-  std::vector<at::Tensor> inputs = {a, b};
-  std::vector<IValue> stack = fmap<IValue>(inputs);
-  k.run(stack);
-  auto o = stack[0].toTensor();
-  auto ref = (a * b).sum(at::kFloat) * (a * b);
-  ASSERT_TRUE(at::allclose(o, ref));
-}
-
-TEST_F(Kernel, SanitizeNames_CUDA) {
-  const auto graph_string = R"IR(
-      graph(%0 : Float(5, 3, strides=[3, 1], device=cuda:0),
-            %1 : Float(5, 3, strides=[3, 1], device=cuda:0)):
-        %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1)
-        %4 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2)
-        return (%4))IR";
-  auto graph = std::make_shared<Graph>();
-  parseIR(graph_string, &*graph);
-  graph->inputs().at(0)->setDebugName("aten::add:");
-  graph->inputs().at(1)->setDebugName("aten::add_");
-  TensorExprKernel k(graph);
-  auto a = at::rand({5, 3}, TensorOptions(kCUDA).dtype(at::kFloat));
-  auto b = at::rand({5, 3}, TensorOptions(kCUDA).dtype(at::kFloat));
-  auto ref = a * (a * b);
-  std::vector<at::Tensor> inputs = {a, b};
-  std::vector<IValue> stack = fmap<IValue>(inputs);
-  k.run(stack);
-  auto o = stack[0].toTensor();
-  ASSERT_TRUE(at::allclose(o, ref));
-}
-
-TEST_F(Kernel, SanitizeConstants_CUDA) {
-  const auto graph_string = R"IR(
-        graph(%x : Float(16, 16, strides=[16, 1], device=cuda:0)):
-          %none : NoneType = prim::Constant()
-          %size : int = prim::Constant[value=16]()
-          %sizes : int[] = prim::ListConstruct(%size, %size)
-          %30 : Device = prim::Constant[value="cuda"]()
-          %y : Float(16, 16, strides=[16, 1], device=cuda:0) = aten::ones(%sizes, %none, %none, %30, %none)
-          %z : Float(16, 16, strides=[16, 1], device=cuda:0) = aten::mul(%x, %y)
-          return (%z))IR";
-  auto graph = std::make_shared<Graph>();
-  parseIR(graph_string, &*graph);
-  // IRParser doesn't support tensor constants, so we insert a call to
-  // aten::ones and then const-prop it
-  ConstantPropagation(graph);
-
-  // We set the name of the constant to include special characters that are
-  // not allowed. This should be fixed by the sanitizer in TensorExprKernel.
-  graph->nodes().front()->output()->setDebugName("illegal.name");
-
-  // Check if we have a constant node with illegal name in the graph.
- auto const_node = graph->nodes().front(); - ASSERT_EQ(const_node->kind(), prim::Constant); - ASSERT_NE(const_node->output()->debugName().find('.'), std::string::npos); - - TensorExprKernel k(graph); - - auto x = at::rand({16, 16}, TensorOptions(kCUDA).dtype(at::kFloat)); - std::vector inputs = {x}; - std::vector stack = fmap(inputs); - k.run(stack); - auto o = stack[0].toTensor(); - auto y = at::ones({16, 16}, TensorOptions(kCUDA).dtype(at::kFloat)); - auto ref = x * y; - ASSERT_TRUE(at::allclose(o, ref)); -} - -TEST_F(Kernel, ConstantTensors) { - const auto graph_string = R"IR( - graph(%x : Float(16, 16, strides=[16, 1], device=cpu)): - %none : NoneType = prim::Constant() - %size : int = prim::Constant[value=16]() - %sizes : int[] = prim::ListConstruct(%size, %size) - %y : Float(16, 16, strides=[16, 1], device=cpu) = aten::ones(%sizes, %none, %none, %none, %none) - %z : Float(16, 16, strides=[16, 1], device=cpu) = aten::mul(%x, %y) - return (%z))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - // IRParser doesn't support tensor constants, so we insert a call to - // aten::ones and then const-prop it - ConstantPropagation(graph); - - TensorExprKernel k(graph); - - auto x = at::rand({16, 16}, TensorOptions(kCPU).dtype(at::kFloat)); - std::vector inputs = {x}; - std::vector stack = fmap(inputs); - k.run(stack); - auto o = stack[0].toTensor(); - auto y = at::ones({16, 16}, TensorOptions(kCPU).dtype(at::kFloat)); - auto ref = x * y; - ASSERT_TRUE(at::allclose(o, ref)); -} - -TEST_F(Kernel, ConstantTensorsNonContiguous) { - const auto graph_string = R"IR( - graph(%x : Float(16, 16, strides=[16, 1], device=cpu)): - %none : NoneType = prim::Constant() - %dtype : int = prim::Constant[value=6]() - %c0 : int = prim::Constant[value=0]() - %c256 : int = prim::Constant[value=256]() - %c16 : int = prim::Constant[value=16]() - %y_flat : Tensor = aten::arange(%c0, %c256, %dtype, %none, %none, %none) - %sizes : int[] = prim::ListConstruct(%c16, %c16) - %y_t : Tensor = aten::view(%y_flat, %sizes) - %y : Tensor = aten::t(%y_t) - %z : Float(16, 16, strides=[16, 1], device=cpu) = aten::mul(%x, %y) - return (%z))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - // IRParser doesn't support tensor constants, so we generate several aten - // calls to produce non-contiguous constant tensor and then const-prop it - ConstantPropagation(graph); - - TensorExprKernel k(graph); - - auto x = at::rand({16, 16}, TensorOptions(kCPU).dtype(at::kFloat)); - std::vector inputs = {x}; - std::vector stack = fmap(inputs); - k.run(stack); - auto o = stack[0].toTensor(); - auto y = at::arange(0, 256, TensorOptions(kCPU).dtype(at::kFloat)) - .view({16, 16}) - .t(); - auto ref = x * y; - ASSERT_TRUE(at::allclose(o, ref)); -} - -TEST_F(Kernel, RunFast) { -#ifdef TORCH_ENABLE_LLVM - // TODO: Implement call_raw in IREval and remove the ifdef - - const auto graph_string = R"IR( - graph(%0 : Float(5, 3, strides=[3, 1], device=cpu), - %1 : Float(5, 3, strides=[1, 5], device=cpu)): - %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1) - %3 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2) - return (%3))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); - auto b = - at::rand({3, 5}, TensorOptions(kCPU).dtype(at::kFloat)).transpose(0, 1); - auto o = at::zeros({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); - auto ref = a * (a * b); - TensorExprKernel k(graph); - - k.runFast({a.data_ptr(), 
b.data_ptr()}, {o.data_ptr()}); - for (size_t i = 0; i < 5 * 3; i++) { - TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); - } -#endif -} - -TEST_F(Kernel, RunWithAllocatedOutputs) { -#ifdef TORCH_ENABLE_LLVM - const auto graph_string = R"IR( - graph(%0 : Float(5, 3, strides=[3, 1], device=cpu), - %1 : Float(5, 3, strides=[1, 5], device=cpu)): - %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1) - %3 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2) - return (%3))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); - auto b = - at::rand({3, 5}, TensorOptions(kCPU).dtype(at::kFloat)).transpose(0, 1); - auto o = at::zeros({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); - auto ref = a * (a * b); - TensorExprKernel k(graph); - - std::vector args = {o, a, b}; - std::vector stack = fmap(args); - k.runWithAllocatedOutputs(stack); - for (size_t i = 0; i < 5 * 3; i++) { - TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); - } -#endif -} - -TEST_F(Kernel, CodegenInspection) { -#ifdef TORCH_ENABLE_LLVM - const auto graph_string = R"IR( - graph(%x : Float(16, 16, strides=[16, 1], device=cpu)): - %none : NoneType = prim::Constant() - %dtype : int = prim::Constant[value=6]() - %c0 : int = prim::Constant[value=0]() - %c256 : int = prim::Constant[value=256]() - %c16 : int = prim::Constant[value=16]() - %y_flat : Tensor = aten::arange(%c0, %c256, %dtype, %none, %none, %none) - %sizes : int[] = prim::ListConstruct(%c16, %c16) - %y_t : Tensor = aten::view(%y_flat, %sizes) - %y : Tensor = aten::t(%y_t) - %z : Float(16, 16, strides=[16, 1], device=cpu) = aten::mul(%x, %y) - return (%z))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - // IRParser doesn't support tensor constants, so we generate several aten - // calls to produce non-contiguous constant tensor and then const-prop it - ConstantPropagation(graph); - - TensorExprKernel k(graph); - - // Check that we could retrieve generated assembly - auto asm_str = k.getCodeText("asm"); - const std::string& asm_verification_pattern = - R"ASM( - # CHECK: .text - # CHECK: retq)ASM"; - torch::jit::testing::FileCheck().run(asm_verification_pattern, asm_str); - - // Check that we could retrieve info about codegen parameters - auto constants = k.getConstantDescriptors(); - auto buf_args = k.getBufferArgs(); - // Expected buf args: [input0, output0, constant0] - ASSERT_EQ(buf_args.size(), 3); - ASSERT_EQ(constants.size(), 1); - ASSERT_TRUE( - !buf_args[0].isVar() && !buf_args[1].isVar() && !buf_args[2].isVar()); -#endif -} - -Tensor lowerNanToNum( - const std::vector& inputs, - const std::vector& outputShape, - const std::vector& outputStrides, - const std::optional& outputType, - at::Device device) { - auto input_buf = std::get(inputs[0]); - auto e = Compute( - "custom_nan_to_num", - outputShape, - outputStrides, - [&](const std::vector& axes) { - std::vector indices(axes.begin(), axes.end()); - auto load = input_buf.load(indices); - return IfThenElse::make(Cast::make(kBool, isnan(load)), 0.0f, load); - }); - return e; -} - -TEST_F(Kernel, CustomLowering) { - const auto graph_string = R"IR( - graph(%x : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu)): - %none : NoneType = prim::Constant() - %y : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::nan_to_num(%x, %none, %none, %none) - return (%y) -)IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - 
std::unordered_map lowerings = { - {aten::nan_to_num, lowerNanToNum}}; - TensorExprKernel k(graph, lowerings); - - auto stmt = k.getCodeGenStmt(); - std::ostringstream oss; - oss << *stmt; - - // Check that our custom lowering is actually used - torch::jit::testing::FileCheck().check("custom_nan_to_num")->run(oss.str()); - torch::jit::testing::FileCheck().check("isnan")->run(oss.str()); -} - -TEST_F(Kernel, Vectorize) { -#ifdef TORCH_ENABLE_LLVM - const auto graph_string = R"IR( - graph(%0 : Float(100, 16, strides=[16, 1], device=cpu), - %1 : Float(100, 16, strides=[16, 1], device=cpu)): - %2 : Float(100, 16, strides=[16, 1]) = aten::mul(%0, %1) - %3 : Float(100, 16, strides=[16, 1]) = aten::mul(%0, %2) - return (%3))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto a = at::rand({100, 16}, TensorOptions(kCPU).dtype(at::kFloat)); - auto b = at::rand({100, 16}, TensorOptions(kCPU).dtype(at::kFloat)); - auto o = at::zeros({100, 16}, TensorOptions(kCPU).dtype(at::kFloat)); - auto ref = a * (a * b); - TensorExprKernel k(graph); - std::vector inputs = {a, b}; - StmtPtr s = k.getCodeGenStmt(); - - std::ostringstream oss; - oss << *s; - - // Check the IR we produced - const std::string& verification_pattern = R"IR(# CHECK: Ramp)IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - std::vector stack = fmap(inputs); - k.run(stack); - o = stack[0].toTensor(); - for (size_t i = 0; i < 100 * 16; i++) { - TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); - } -#endif -} - -// TODO: To vectorize loopnest for 100x3 case, we need to flatten loops first. -TEST_F(Kernel, DISABLED_FlattenVectorize) { -#ifdef TORCH_ENABLE_LLVM - const auto graph_string = R"IR( - graph(%0 : Float(100, 3, strides=[3, 1], device=cpu), - %1 : Float(100, 3, strides=[3, 1], device=cpu)): - %2 : Float(100, 3, strides=[3, 1]) = aten::mul(%0, %1) - %3 : Float(100, 3, strides=[3, 1]) = aten::mul(%0, %2) - return (%3))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto a = at::rand({100, 3}, TensorOptions(kCPU).dtype(at::kFloat)); - auto b = at::rand({100, 3}, TensorOptions(kCPU).dtype(at::kFloat)); - auto o = at::zeros({100, 3}, TensorOptions(kCPU).dtype(at::kFloat)); - auto ref = a * (a * b); - TensorExprKernel k(graph); - std::vector inputs = {a, b}; - StmtPtr s = k.getCodeGenStmt(); - - std::ostringstream oss; - oss << *s; - - // Check the IR we produced - const std::string& verification_pattern = R"IR(# CHECK: Ramp)IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - std::vector stack = fmap(inputs); - k.run(stack); - o = stack[0].toTensor(); - for (size_t i = 0; i < 100 * 3; i++) { - TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); - } -#endif -} - -TEST_F(Kernel, Strided1dWithinBounds) { - auto ir = R"IR( - graph(%0 : Float(3, strides=[1], device=cpu), - %1 : Float(3, strides=[2], device=cpu)): - %2 : int = prim::Constant[value=1]() - %3 : Float(3, strides=[1]) = aten::add(%0, %1, %2) - return (%3))IR"; - auto graph = std::make_shared(); - std::unordered_map vmap; - parseIR(ir, graph.get(), vmap); - TensorExprKernel k(graph); - - auto a = at::rand({3}, TensorOptions(kCPU).dtype(at::kFloat)); - auto b = at::rand({6}, TensorOptions(kCPU).dtype(at::kFloat)) - .index({Slice(None, None, 2)}); - auto expect = a + b; - - std::vector inputs = {a, b}; - - std::vector stack = fmap(inputs); - k.run(stack); - - auto output = stack[0].toTensor(); - - for (size_t i = 0; i < 3; ++i) 
{ - TORCH_CHECK_EQ( - ((float*)output.data_ptr())[i], ((float*)expect.data_ptr())[i]); - } -} - -TEST_F(Kernel, InputAsOutput) { - const auto graph_string = R"IR( - graph(%x : Float(5, 3, strides=[3, 1], device=cpu), - %y : Float(5, 3, strides=[1, 5], device=cpu)): - return (%x, %y))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto x = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); - auto y = - at::rand({3, 5}, TensorOptions(kCPU).dtype(at::kFloat)).transpose(0, 1); - TensorExprKernel k(graph); - std::vector inputs = {x, y}; - - std::vector stack = fmap(inputs); - k.run(stack); - CHECK(at::allclose(x, stack[0].toTensor())); - CHECK(at::allclose(y, stack[1].toTensor())); -} - -TEST_F(Kernel, ScalarOut) { - auto ir = R"IR( -graph(%x : int, %y : int): - %z : int = aten::mul(%x, %y) - %r : int = aten::mul(%z, %x) - return (%r, %z))IR"; - auto graph = std::make_shared(); - std::unordered_map vmap; - parseIR(ir, graph.get(), vmap); - TensorExprKernel k(graph); - - auto stmt = k.getCodeGenStmt(); - std::ostringstream oss; - oss << *stmt; - - // Verify the generated IR. We expect to see a scalar variable (Let) followed - // by a store to a 0-dim buffer. - const std::string& verification_pattern = R"IR( -# CHECK: int64_t -# CHECK-NEXT: [0ll] = -# CHECK-NEXT: int64_t -# CHECK-NEXT: [0ll] = -)IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - int64_t x = 2, y = 3, r = 0, z = 0; - - // Verify that TEK::runFast works correctly with scalar outputs - std::vector inputs = {&x, &y}; - std::vector outputs = {&r, &z}; - k.runFast(inputs, outputs); - TORCH_CHECK_EQ(z, x * y); - TORCH_CHECK_EQ(r, z * x); - - // Verify that TEK::run works correctly with scalar outputs - std::vector stack = {x, y}; - k.run(stack); - TORCH_CHECK_EQ(stack[0], x * y * x); - TORCH_CHECK_EQ(stack[1], x * y); -} - -TEST_F(Kernel, ScalarTensorOut) { - auto ir = R"IR( -graph(%x : int, - %xt : Long(3, strides=[1], device=cpu), - %y : int, - %yt : Long(3, strides=[1], device=cpu)): - %z : int = aten::mul(%x, %y) - %r : int = aten::mul(%z, %x) - %zt : Long(3, strides=[1], device=cpu) = aten::mul(%xt, %y) - %rt : Long(3, strides=[1], device=cpu) = aten::mul(%zt, %xt) - return (%r, %rt, %z, %zt))IR"; - auto graph = std::make_shared(); - std::unordered_map vmap; - parseIR(ir, graph.get(), vmap); - TensorExprKernel k(graph); - int64_t x = 2, y = 3, r = 0, z = 0; - auto xt = at::ones({3}, TensorOptions(kCPU).dtype(at::kLong)) * 2; - auto yt = at::ones({3}, TensorOptions(kCPU).dtype(at::kLong)) * 3; - auto zt = at::zeros({3}, TensorOptions(kCPU).dtype(at::kLong)); - auto rt = at::zeros({3}, TensorOptions(kCPU).dtype(at::kLong)); - - // Verify that TEK::runFast works correctly with mixed scalar and tensor - // inputs/outputs - std::vector inputs = {&x, xt.data_ptr(), &y, yt.data_ptr()}; - std::vector outputs = {&r, rt.data_ptr(), &z, zt.data_ptr()}; - k.runFast(inputs, outputs); - TORCH_CHECK_EQ(z, x * y); - TORCH_CHECK_EQ(r, z * x); - ASSERT_TRUE(at::equal(zt, xt * yt)); - ASSERT_TRUE(at::equal(rt, zt * xt)); - - // Verify that TEK::run works correctly with mixed scalar and tensor - // inputs/outputs - std::vector stack = {x, xt, y, yt}; - k.run(stack); - TORCH_CHECK_EQ(stack[0], x * y * x); - ASSERT_TRUE(at::equal(stack[1].toTensor(), xt * yt * xt)); - TORCH_CHECK_EQ(stack[2], x * y); - ASSERT_TRUE(at::equal(stack[3].toTensor(), xt * yt)); -} - -TEST_F(Kernel, FuseLoopsWithVariableBounds) { -#ifdef TORCH_ENABLE_LLVM - bool old_cat_wo_conditionals = getCatWoConditionals(); 
- getCatWoConditionals() = true; - const auto graph_string = R"IR( - graph(%a : Float(SS(-2), 3, SS(-3), requires_grad=0, device=cpu), - %b : Float(SS(-2), 7, SS(-3), requires_grad=0, device=cpu), - %c : Float(SS(-2), 9, SS(-3), requires_grad=0, device=cpu), - %SS_2 : int, - %SS_3 : int): - %dim : int = prim::Constant[value=1]() - %inputs : Tensor[] = prim::ListConstruct(%a, %b, %c) - %r : Float(SS(-2), 19, SS(-3), requires_grad=0, device=cpu) = aten::cat(%inputs, %dim) # new size: [5,19,2] - return (%r))IR"; - std::shared_ptr graph = std::make_shared(); - torch::jit::parseIR(graph_string, graph.get()); - - std::vector symbolic_shape_inputs = {-2, -3}; - - std::vector input_desc = { - torch::jit::StrideInput::TENSOR_CONT}; - std::unordered_map< - const torch::jit::Value*, - std::vector> - symbolic_strides; - symbolic_strides[graph->inputs().at(0)] = input_desc; - symbolic_strides[graph->inputs().at(1)] = input_desc; - symbolic_strides[graph->inputs().at(2)] = input_desc; - symbolic_strides[graph->outputs().at(0)] = input_desc; - - TensorExprKernel kernel( - graph, {}, symbolic_shape_inputs, false, symbolic_strides); - - std::ostringstream oss; - oss << *kernel.getCodeGenStmt(); - const std::string& verification_pattern = - R"IR( -# CHECK: for (int64_t i -# CHECK-NEXT: for (int64_t j -# CHECK-NEXT: for (int64_t k -# CHECK: for (int64_t j -# CHECK-NEXT: for (int64_t k -# CHECK: for (int64_t j -# CHECK-NEXT: for (int64_t k -# CHECK-NOT: for (int64_t i - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - auto run_kernel = [&](int dim1, int dim2) { - auto a = - at::rand({dim1, 3, dim2}, at::TensorOptions(kCPU).dtype(at::kFloat)); - auto b = - at::rand({dim1, 7, dim2}, at::TensorOptions(kCPU).dtype(at::kFloat)); - auto c = - at::rand({dim1, 9, dim2}, at::TensorOptions(kCPU).dtype(at::kFloat)); - - auto ref = at::cat({a, b, c}, 1); - - std::vector stack = - fmap(std::vector({a, b, c})); - stack.emplace_back(dim1); - stack.emplace_back(dim2); - kernel.run(stack); - - auto o = stack[0].toTensor(); - ASSERT_TRUE(at::allclose(o, ref)); - }; - - run_kernel(10, 20); - getCatWoConditionals() = old_cat_wo_conditionals; -#endif -} - -TEST_F(Kernel, FuseLoopsWithVariableConcatDim) { -#ifdef TORCH_ENABLE_LLVM - bool old_cat_wo_conditionals = getCatWoConditionals(); - getCatWoConditionals() = true; - const auto graph_string = R"IR( - graph(%a : Float(SS(-2), SS(-4), SS(-3), requires_grad=0, device=cpu), - %b : Float(SS(-2), SS(-4), SS(-3), requires_grad=0, device=cpu), - %c : Float(SS(-2), SS(-4), SS(-3), requires_grad=0, device=cpu), - %SS_2 : int, - %SS_3 : int, - %SS_4 : int, - %SS_5 : int): - %dim : int = prim::Constant[value=1]() - %inputs : Tensor[] = prim::ListConstruct(%a, %b, %c) - %r : Float(SS(-2), SS(-5), SS(-3), requires_grad=0, device=cpu) = aten::cat(%inputs, %dim) # new size: [5,19,2] - return (%r))IR"; - std::shared_ptr graph = std::make_shared(); - torch::jit::parseIR(graph_string, graph.get()); - - std::vector symbolic_shape_inputs = {-2, -3, -4, -5}; - - std::vector input_desc = { - torch::jit::StrideInput::TENSOR_CONT}; - std::unordered_map< - const torch::jit::Value*, - std::vector> - symbolic_strides; - symbolic_strides[graph->inputs().at(0)] = input_desc; - symbolic_strides[graph->inputs().at(1)] = input_desc; - symbolic_strides[graph->inputs().at(2)] = input_desc; - symbolic_strides[graph->outputs().at(0)] = input_desc; - - TensorExprKernel kernel( - graph, {}, symbolic_shape_inputs, false, symbolic_strides); - - std::ostringstream oss; - oss << 
*kernel.getCodeGenStmt(); - const std::string& verification_pattern = - R"IR( -# CHECK: for (int64_t i -# CHECK-NEXT: for (int64_t j -# CHECK-NEXT: for (int64_t k -# CHECK: for (int64_t j -# CHECK-NEXT: for (int64_t k -# CHECK: for (int64_t j -# CHECK-NEXT: for (int64_t k -# CHECK-NOT: for (int64_t i - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - auto run_kernel = [&](int dim1, int dim2, int dim3) { - auto a = - at::rand({dim1, dim3, dim2}, at::TensorOptions(kCPU).dtype(at::kFloat)); - auto b = - at::rand({dim1, dim3, dim2}, at::TensorOptions(kCPU).dtype(at::kFloat)); - auto c = - at::rand({dim1, dim3, dim2}, at::TensorOptions(kCPU).dtype(at::kFloat)); - - auto ref = at::cat({a, b, c}, 1); - - std::vector stack = - fmap(std::vector({a, b, c})); - stack.emplace_back(dim1); - stack.emplace_back(dim2); - stack.emplace_back(dim3); - stack.emplace_back(3 * dim3); - kernel.run(stack); - - auto o = stack[0].toTensor(); - ASSERT_TRUE(at::allclose(o, ref)); - }; - - run_kernel(10, 20, 15); - getCatWoConditionals() = old_cat_wo_conditionals; -#endif -} - -TEST_F(Kernel, DoNotFuseLoopsWithMismatchingVariableDims) { -#ifdef TORCH_ENABLE_LLVM - bool old_cat_wo_conditionals = getCatWoConditionals(); - getCatWoConditionals() = true; - const auto graph_string = R"IR( - graph(%a : Float(SS(-2), SS(-4), SS(-3), requires_grad=0, device=cpu), - %b : Float(SS(-2), SS(-5), SS(-3), requires_grad=0, device=cpu), - %SS_2 : int, - %SS_3 : int, - %SS_4 : int, - %SS_5 : int, - %SS_6 : int): - %dim : int = prim::Constant[value=1]() - %inputs : Tensor[] = prim::ListConstruct(%a, %b) - %r : Float(SS(-2), SS(-6), SS(-3), requires_grad=0, device=cpu) = aten::cat(%inputs, %dim) # new size: [5,19,2] - return (%r))IR"; - std::shared_ptr graph = std::make_shared(); - torch::jit::parseIR(graph_string, graph.get()); - - std::vector symbolic_shape_inputs = {-2, -3, -4, -5, -6}; - - std::vector input_desc = { - torch::jit::StrideInput::TENSOR_CONT}; - std::unordered_map< - const torch::jit::Value*, - std::vector> - symbolic_strides; - symbolic_strides[graph->inputs().at(0)] = input_desc; - symbolic_strides[graph->inputs().at(1)] = input_desc; - symbolic_strides[graph->outputs().at(0)] = input_desc; - - TensorExprKernel kernel( - graph, {}, symbolic_shape_inputs, false, symbolic_strides); - - std::ostringstream oss; - oss << *kernel.getCodeGenStmt(); - const std::string& verification_pattern = - R"IR( -# CHECK: for (int64_t i -# CHECK-NEXT: for (int64_t j -# CHECK-NEXT: for (int64_t k -# CHECK: for (int64_t j -# CHECK-NEXT: for (int64_t k -# CHECK-NOT: for (int64_t j -# CHECK-NOT: for (int64_t i - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - auto run_kernel = [&](int dim2, int dim3, int dim4, int dim5) { - auto a = - at::rand({dim2, dim4, dim3}, at::TensorOptions(kCPU).dtype(at::kFloat)); - auto b = - at::rand({dim2, dim5, dim3}, at::TensorOptions(kCPU).dtype(at::kFloat)); - - auto ref = at::cat({a, b}, 1); - - std::vector stack = fmap(std::vector({a, b})); - stack.emplace_back(dim2); - stack.emplace_back(dim3); - stack.emplace_back(dim4); - stack.emplace_back(dim5); - stack.emplace_back(dim4 + dim5); - kernel.run(stack); - - auto o = stack[0].toTensor(); - ASSERT_TRUE(at::allclose(o, ref)); - }; - - run_kernel(10, 20, 15, 8); - getCatWoConditionals() = old_cat_wo_conditionals; -#endif -} - -} // namespace jit -} // namespace torch diff --git a/test/cpp/tensorexpr/test_llvm.cpp b/test/cpp/tensorexpr/test_llvm.cpp deleted file mode 100644 index 
f6ffc84f62c0..000000000000 --- a/test/cpp/tensorexpr/test_llvm.cpp +++ /dev/null @@ -1,1799 +0,0 @@ -#ifdef TORCH_ENABLE_LLVM -#include - -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include - -namespace torch { -namespace jit { -using namespace torch::jit::tensorexpr; - -using LLVMExprEval = ExprEval; - -// Typed tests, can't use gtest params here due to the way we instantiate tests. -#define TEST_LLVM_SCALAR_TYPES(_) \ - _(uint8_t, Byte, 24) \ - _(int8_t, Char, -20) \ - _(int16_t, Short, 3332) \ - _(int, Int, 123456) \ - _(int64_t, Long, 2631563121321) \ - _(float, Float, 0.122) \ - _(double, Double, 0.21312) \ - _(at::Half, Half, 0.128f) - -#define IMM_TEST(Type, Name, Val) \ - TEST(LLVM, Name##ImmTest) { \ - auto a = Name##Imm::make(Val); \ - LLVMExprEval cg(a); \ - if (std::is_floating_point()) { \ - ASSERT_NEAR(cg.value(), Val, 0.1); \ - } else { \ - ASSERT_EQ(cg.value(), Val); \ - } \ - } -TEST_LLVM_SCALAR_TYPES(IMM_TEST) -#undef IMM_TEST - -#define ADD_TEST(Type, Name, Val) \ - TEST(LLVM, Name##AddTest) { \ - auto a = Name##Imm::make(Val); \ - auto b = Name##Imm::make(Val * 2); \ - auto c = Add::make(a, b); \ - LLVMExprEval cg(c); \ - if (std::is_floating_point()) { \ - ASSERT_NEAR(cg.value(), Val * 3, 0.1); \ - } else { \ - ASSERT_EQ(cg.value(), Val * 3); \ - } \ - } -TEST_LLVM_SCALAR_TYPES(ADD_TEST) -#undef ADD_TEST - -#define SUB_TEST(Type, Name, Val) \ - TEST(LLVM, Name##SubTest) { \ - auto a = Name##Imm::make(Val * 2); \ - auto b = Name##Imm::make(Val); \ - auto c = Sub::make(a, b); \ - LLVMExprEval cg(c); \ - if (std::is_floating_point()) { \ - ASSERT_NEAR(cg.value(), Val, 0.1); \ - } else { \ - ASSERT_EQ(cg.value(), Val); \ - } \ - } -TEST_LLVM_SCALAR_TYPES(SUB_TEST) -#undef SUB_TEST - -#define MUL_TEST(Type, Name, Val) \ - TEST(LLVM, Name##MulTest) { \ - auto a = Name##Imm::make(Val); \ - auto b = Name##Imm::make((Type)4); \ - auto c = Mul::make(a, b); \ - LLVMExprEval cg(c); \ - if (std::is_floating_point()) { \ - ASSERT_NEAR(cg.value(), Val * 4, 0.1); \ - } else { \ - ASSERT_EQ(cg.value(), Val * 4); \ - } \ - } -TEST_LLVM_SCALAR_TYPES(MUL_TEST) -#undef MUL_TEST - -#define DIV_TEST(Type, Name, Val) \ - TEST(LLVM, Name##DivTest) { \ - auto a = Name##Imm::make((Type)6); \ - auto b = Name##Imm::make((Type)3); \ - auto c = Div::make(a, b); \ - LLVMExprEval cg(c); \ - if (std::is_floating_point()) { \ - ASSERT_NEAR(cg.value(), 2, 0.1); \ - } else { \ - ASSERT_EQ(cg.value(), 2); \ - } \ - } -TEST_LLVM_SCALAR_TYPES(DIV_TEST) -#undef DIV_TEST - -TEST(LLVM, IntToFloatCastTest) { - auto a = IntImm::make(2); - auto b = Cast::make(kFloat, a); - LLVMExprEval cg(b, {}); - ASSERT_EQ(cg.value(), 2.0); -} - -TEST(LLVM, FloatToIntCastTest) { - auto a = FloatImm::make(2.0); - auto b = Cast::make(kInt, a); - LLVMExprEval cg(b); - ASSERT_EQ(cg.value(), 2); -} - -TEST(LLVM, IntToLongCastTest) { - auto a = IntImm::make(12345); - auto b = Cast::make(kLong, a); - LLVMExprEval cg(b); - ASSERT_EQ(cg.value(), 12345); -} - -TEST(LLVM, ByteToCharCastTest) { - auto a = ByteImm::make(250); - auto b = Cast::make(kChar, a); - LLVMExprEval cg(b); - ASSERT_EQ(cg.value(), (int8_t)250); -} - -TEST(LLVM, HalfToLongCastTest) { - auto a = HalfImm::make(2.0); - auto b = Cast::make(kLong, a); - LLVMExprEval cg(b); - ASSERT_EQ(cg.value(), 2); -} - -TEST(LLVM, ByteToDoubleCastTest) { - auto a = ByteImm::make(2); - auto b = Cast::make(kDouble, a); - LLVMExprEval cg(b); - ASSERT_EQ(cg.value(), 2); -} - -TEST(LLVM, 
FloatToByteCastTest) { - auto a = FloatImm::make(254.0); - auto b = Cast::make(kByte, a); - LLVMExprEval cg(b); - ASSERT_EQ(cg.value(), 254); -} - -TEST(LLVM, FloatToCharCastTest) { - auto a = FloatImm::make(-2.0); - auto b = Cast::make(kChar, a); - LLVMExprEval cg(b); - ASSERT_EQ(cg.value(), -2); -} - -TEST(LLVM, ByteToFloatCastTest) { - auto a = ByteImm::make(254); - auto b = Cast::make(kFloat, a); - LLVMExprEval cg(b); - ASSERT_EQ(cg.value(), 254.0); -} - -TEST(LLVM, CharToFloatCastTest) { - auto a = CharImm::make(-2); - auto b = Cast::make(kFloat, a); - LLVMExprEval cg(b); - ASSERT_EQ(cg.value(), -2.0); -} - -TEST(LLVM, BitCast) { - /* constexpr int16_t ref16 = 1337; */ - constexpr int32_t ref32 = 1337; - constexpr int64_t ref64 = 1337; - constexpr float reff32 = 1337.0f; - constexpr double reff64 = 1337.0f; - - // this is broken - /*{ - at::Half k_; - at::Half* k = &k_; - *reinterpret_cast(k) = ref16; - auto a = HalfImm::make(k); - auto b = BitCast::make(kShort, a); - LLVMExprEval cg(b); - ASSERT_EQ(cg.value(), ref16); - }*/ - - { - float k = raw_bitcast(ref32); - auto a = FloatImm::make(k); - auto b = BitCast::make(kInt, a); - LLVMExprEval cg(b); - ASSERT_EQ(cg.value(), ref32); - } - - { - double k = raw_bitcast(ref64); - auto a = DoubleImm::make(k); - auto b = BitCast::make(kLong, a); - LLVMExprEval cg(b); - ASSERT_EQ(cg.value(), ref64); - } - - { - int64_t k = raw_bitcast(reff64); - auto a = LongImm::make(k); - auto b = BitCast::make(kDouble, a); - LLVMExprEval cg(b); - ASSERT_EQ(cg.value(), reff64); - } - - { - int32_t k = raw_bitcast(reff32); - auto a = IntImm::make(k); - auto b = BitCast::make(kFloat, a); - LLVMExprEval cg(b); - ASSERT_EQ(cg.value(), reff32); - } -} - -TEST(LLVM, fastLogFloat) { - const int kTotalSize = 128 * 128; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); - - VarHandle index = VarHandle("index", kInt); - ExprHandle load_a = a_buf.load(index); - StmtPtr store_b = b_buf.store({index}, fast_log(load_a)); - StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); - - PaddedBuffer a_v(kTotalSize); - PaddedBuffer b_v(kTotalSize); - - for (const auto i : c10::irange(kTotalSize)) { - a_v(i) = at::randn({1}).item().to(); - } - - LLVMCodeGen ir_eval(stmt, {a_buf, b_buf}); - ir_eval.call({a_v, b_v}); - - for (const auto i : c10::irange(kTotalSize)) { - auto test = b_v(i); - auto ref = std::log(a_v(i)); - if (std::isnan(ref)) { - ASSERT_EQ(std::isnan(test), true); - } else { - ASSERT_FLOAT_EQ(test, ref); - } - } -} - -TEST(LLVM, LetTest01) { - BufHandle a("A", {1}, kFloat); - std::vector v = {1, 0}; - std::vector args({v.data()}); - VarHandle x("x", kFloat); - auto block = Block::make({ - Let::make(x, 3.f), - a.store({0}, ExprHandle(2.f) + (x * ExprHandle(3.f) + ExprHandle(4.f))), - }); - - LLVMCodeGen cg(block, {a}); - ASSERT_EQ(cg.value(args), 0); - ASSERT_EQ(v[0], 2.f + 3.f * 3.f + 4.f); -} - -TEST(LLVM, LetTest02) { - BufHandle a("A", {1}, kFloat); - std::vector v = {1, 0}; - std::vector args({v.data()}); - VarHandle x("x", kFloat); - VarHandle y("y", kFloat); - auto block = Block::make( - {Let::make(x, 3.f), - Let::make(y, 6.f), - a.store( - {IntImm::make(0)}, - ExprHandle(2.f) + (x * ExprHandle(3.f) + y * ExprHandle(4.f)))}); - - LLVMCodeGen cg(block, {a}); - ASSERT_EQ(cg.value(args), 0); - ASSERT_EQ(v[0], 2.f + 3.f * 3.f + 6.f * 4.f); -} - -TEST(LLVM, LetTestMultitype) { - BufHandle a("A", {1}, kDouble); - std::vector v = {1, 0}; - std::vector args({v.data()}); - VarHandle x("x", kByte); - 
VarHandle y("y", kHalf); - auto block = Block::make( - {Let::make(x, 3), - Let::make(y, 6.f), - a.store( - {0}, - Cast::make( - kDouble, - ExprHandle(2.f) + - (x * ExprHandle(3.f) + y * ExprHandle(4.f))))}); - - LLVMCodeGen cg(block, {a}); - ASSERT_EQ(cg.value(args), 0); - ASSERT_EQ(v[0], 2.f + 3 * 3.f + 6.f * 4.f); -} - -TEST(LLVM, BufferTest) { - BufHandle a("A", {32}, kFloat); - std::vector v(5); - std::vector args({v.data()}); - auto rv = IntImm::make(0); - LLVMExprEval cg(rv, {a}); - ASSERT_EQ(cg.value(args), 0); -} - -TEST(LLVM, BlockTest) { - BufHandle a("A", {32}, kInt); - std::vector v = {1, 2}; - std::vector args({v.data()}); - - auto block = Block::make({ - a.store({0}, 3), - a.store({1}, 4), - a.store({0}, 4), - }); - - LLVMCodeGen cg(block, {a}); - ASSERT_EQ(cg.value(args), 0); - ASSERT_EQ(v[0], 4); - ASSERT_EQ(v[1], 4); -} - -TEST(LLVM, LoadStoreTest) { - BufHandle a("A", {1}, kInt); - BufHandle b("B", {1}, kInt); - std::vector a_buffer = {42}; - std::vector b_buffer = {-11}; - - auto store = b.store({0}, a.load(0)); - LLVMCodeGen cg(store, {a, b}); - std::vector args({a_buffer.data(), b_buffer.data()}); - ASSERT_EQ(cg.value(args), 0); - ASSERT_EQ(a_buffer[0], 42); - ASSERT_EQ(b_buffer[0], 42); -} - -TEST(LLVM, IfThenElseTest) { - BufHandle a("A", {1}, kInt); - BufHandle b("B", {1}, kInt); - BufHandle c("C", {1}, kInt); - std::vector a_buffer = {42}; - std::vector b_buffer = {-11}; - std::vector c_buffer = {1}; - - auto store = b.store({0}, IfThenElse::make(c.load(0), a.load(0), 0)); - LLVMCodeGen cg(store, {a, b, c}); - std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); - ASSERT_EQ(cg.value(args), 0); - ASSERT_EQ(a_buffer[0], 42); - ASSERT_EQ(b_buffer[0], 42); -} - -// if (x < 10) x = x + 1 -TEST(LLVM, CondNoFalseBlockTest) { - BufHandle x("X", {1}, kInt); - auto cmp = CompareSelect::make(x.load(0), 10, CompareSelectOperation::kLT); - auto cond = Cond::make(cmp, x.store({0}, x.load(0) + 1), nullptr); - - for (int32_t x_value : {0, 10, 20}) { - std::vector x_buffer = {x_value}; - std::vector args({x_buffer.data()}); - LLVMCodeGen cg(cond, {x}); - ASSERT_EQ(cg.value(args), 0); - if (x_value < 10) { - ASSERT_EQ(x_buffer[0], x_value + 1); - } else { - ASSERT_EQ(x_buffer[0], x_value); - } - } -} - -// if (x < 10) { -// x = x + 1; -// } else { -// x = x - 1; -// } -TEST(LLVM, CondTest) { - BufHandle x("X", {1}, kInt); - auto cmp = CompareSelect::make(x.load(0), 10, CompareSelectOperation::kLT); - auto cond = - Cond::make(cmp, x.store({0}, x.load(0) + 1), x.store({0}, x.load(0) - 1)); - auto block = Block::make({ - cond, - x.store({0}, x.load(0) * 2), - }); - - for (int32_t x_value : {0, 10, 20}) { - std::vector x_buffer = {x_value}; - std::vector args({x_buffer.data()}); - LLVMCodeGen cg(block, {x}); - ASSERT_EQ(cg.value(args), 0); - if (x_value < 10) { - ASSERT_EQ(x_buffer[0], (x_value + 1) * 2); - } else { - ASSERT_EQ(x_buffer[0], (x_value - 1) * 2); - } - } -} - -// if (x < 10) { -// if (x > 5) { -// x = x + 1; -// } else { -// x = x - 1; -// } -// } else { -// if (x <= 15) { -// x = x + 2; -// } else { -// x = x - 2; -// } -// } -TEST(LLVM, CondNestedTest) { - BufHandle x("X", {1}, kInt); - auto true_cmp = - CompareSelect::make(x.load(0), 5, CompareSelectOperation::kGT); - auto true_cond = Cond::make( - true_cmp, x.store({0}, x.load(0) + 1), x.store({0}, x.load(0) - 1)); - auto false_cmp = - CompareSelect::make(x.load(0), 15, CompareSelectOperation::kLE); - auto false_cond = Cond::make( - false_cmp, x.store({0}, x.load(0) + 2), x.store({0}, x.load(0) - 2)); 
- auto cmp = CompareSelect::make(x.load(0), 10, CompareSelectOperation::kLT); - auto cond = Cond::make(cmp, true_cond, false_cond); - - for (int32_t x_value : {0, 8, 15, 20}) { - std::vector x_buffer = {x_value}; - std::vector args({x_buffer.data()}); - LLVMCodeGen cg(cond, {x}); - ASSERT_EQ(cg.value(args), 0); - if (x_value < 10) { - if (x_value > 5) { - ASSERT_EQ(x_buffer[0], x_value + 1); - } else { - ASSERT_EQ(x_buffer[0], x_value - 1); - } - } else { - if (x_value <= 15) { - ASSERT_EQ(x_buffer[0], x_value + 2); - } else { - ASSERT_EQ(x_buffer[0], x_value - 2); - } - } - } -} - -TEST(LLVM, DirectVectorization) { - constexpr int M = 3; - constexpr int N = 64; - BufHandle a("a", {M, N}, kFloat); - BufHandle b("b", {M, N}, kFloat); - BufHandle c("c", {M, N}, kFloat); - VarHandle m("m", kInt); - VarHandle n("n", kInt); - StmtPtr s = For::make( - m, - 0, - M, - Store::make( - c, - {Ramp::make(m * 64, 1, 64)}, - Load::make({kFloat, 64}, a, {Ramp::make(m * 64, 1, 64)}) * - Load::make({kFloat, 64}, b, {Ramp::make(m * 64, 1, 64)}))); - LLVMCodeGen cg(s, {a, b, c}); -} - -TEST(LLVM, VecLoadStoreTest) { - BufHandle a("A", {1}, kInt); - BufHandle b("B", {1}, kInt); - std::vector a_buffer = {1, 1, 1, 1}; - std::vector b_buffer = {2, 2, 2, 2}; - - auto store = b.store({Ramp::make(0, 1, 4)}, a.load({Ramp::make(0, 1, 4)})); - LLVMCodeGen cg(store, {a, b}); - std::vector args({a_buffer.data(), b_buffer.data()}); - ASSERT_EQ(cg.value(args), 0); - ASSERT_EQ(a_buffer[0], 1); - ASSERT_EQ(a_buffer[1], 1); - ASSERT_EQ(a_buffer[2], 1); - ASSERT_EQ(a_buffer[3], 1); - ASSERT_EQ(b_buffer[0], 1); - ASSERT_EQ(b_buffer[1], 1); - ASSERT_EQ(b_buffer[2], 1); - ASSERT_EQ(b_buffer[3], 1); -} - -#define FLOAT_INTRINSICS_TEST(Name, Lanes) \ - TEST(LLVM, VecFloat_##Name##Lane##Lanes##Test) { \ - BufHandle a("A", {1}, kFloat); \ - BufHandle b("B", {1}, kFloat); \ - float val = 0.5f; \ - std::vector a_buffer(Lanes, val); \ - std::vector b_buffer(Lanes, val); \ - auto store = b.store( \ - {Ramp::make(0, 1, Lanes)}, Name(a.load({Ramp::make(0, 1, Lanes)}))); \ - LLVMCodeGen cg(store, {a, b}); \ - std::vector args({a_buffer.data(), b_buffer.data()}); \ - ASSERT_EQ(cg.value(args), 0); \ - for (const auto i : c10::irange(Lanes)) { \ - ASSERT_FLOAT_EQ(a_buffer[i], val); \ - } \ - } // namespace jit -FLOAT_INTRINSICS_TEST(erf, 4) -FLOAT_INTRINSICS_TEST(erfc, 4) -FLOAT_INTRINSICS_TEST(acos, 4) -FLOAT_INTRINSICS_TEST(asin, 4) -FLOAT_INTRINSICS_TEST(atan, 4) -FLOAT_INTRINSICS_TEST(cosh, 4) -FLOAT_INTRINSICS_TEST(sinh, 4) -FLOAT_INTRINSICS_TEST(tanh, 4) -FLOAT_INTRINSICS_TEST(expm1, 4) -FLOAT_INTRINSICS_TEST(lgamma, 4) -FLOAT_INTRINSICS_TEST(erf, 8) -FLOAT_INTRINSICS_TEST(erfc, 8) -FLOAT_INTRINSICS_TEST(acos, 8) -FLOAT_INTRINSICS_TEST(asin, 8) -FLOAT_INTRINSICS_TEST(atan, 8) -FLOAT_INTRINSICS_TEST(cosh, 8) -FLOAT_INTRINSICS_TEST(sinh, 8) -FLOAT_INTRINSICS_TEST(tanh, 8) -FLOAT_INTRINSICS_TEST(expm1, 8) -FLOAT_INTRINSICS_TEST(lgamma, 8) -#undef FLOAT_INTRINSICS_TEST - -#define DOUBLE_INTRINSICS_TEST(Name, Lanes) \ - TEST(LLVM, VecDouble_##Name##Lane##Lanes##Test) { \ - BufHandle a("A", {1}, kDouble); \ - BufHandle b("B", {1}, kDouble); \ - float val = 0.5f; \ - std::vector a_buffer(Lanes, val); \ - std::vector b_buffer(Lanes, val); \ - auto store = b.store( \ - {Ramp::make(0, 1, Lanes)}, Name(a.load({Ramp::make(0, 1, Lanes)}))); \ - LLVMCodeGen cg(store, {a, b}); \ - std::vector args({a_buffer.data(), b_buffer.data()}); \ - ASSERT_EQ(cg.value(args), 0); \ - for (const auto i : c10::irange(Lanes)) { \ - ASSERT_FLOAT_EQ(a_buffer[i], val); \ 
- } \ - } // namespace jit -DOUBLE_INTRINSICS_TEST(erf, 2) -DOUBLE_INTRINSICS_TEST(erfc, 2) -DOUBLE_INTRINSICS_TEST(acos, 2) -DOUBLE_INTRINSICS_TEST(asin, 2) -DOUBLE_INTRINSICS_TEST(atan, 2) -DOUBLE_INTRINSICS_TEST(cosh, 2) -DOUBLE_INTRINSICS_TEST(sinh, 2) -DOUBLE_INTRINSICS_TEST(tanh, 2) -DOUBLE_INTRINSICS_TEST(expm1, 2) -DOUBLE_INTRINSICS_TEST(lgamma, 2) -DOUBLE_INTRINSICS_TEST(erf, 4) -DOUBLE_INTRINSICS_TEST(erfc, 4) -DOUBLE_INTRINSICS_TEST(acos, 4) -DOUBLE_INTRINSICS_TEST(asin, 4) -DOUBLE_INTRINSICS_TEST(atan, 4) -DOUBLE_INTRINSICS_TEST(cosh, 4) -DOUBLE_INTRINSICS_TEST(sinh, 4) -DOUBLE_INTRINSICS_TEST(tanh, 4) -DOUBLE_INTRINSICS_TEST(expm1, 4) -DOUBLE_INTRINSICS_TEST(lgamma, 4) -#undef DOUBLE_INTRINSICS_TEST - -TEST(LLVM, VectorizerLoadStoreTest) { - BufHandle a("A", {1}, kInt); - - Tensor c = Compute("c", {4}, [&](const VarHandle& i) { return a.load(i); }); - - BufHandle c_buf(c.buf()); - LoopNest l({c}); - StmtPtr s = l.root_stmt(); - ASSERT_TRUE(LoopNest::vectorize(to(to(s)->front()))); - - ASSERT_TRUE(to(to(s)->front()) == nullptr); - - LLVMCodeGen cg(s, {a, c_buf}); - - std::vector a_vec(4, 21); - std::vector c_vec(4, 0); - std::vector args({a_vec.data(), c_vec.data()}); - ASSERT_EQ(cg.value(args), 0); - assertAllEqual(c_vec, 21); -} - -TEST(LLVM, VectorizeBitCast) { - BufHandle a("A", {128}, kInt); - - Tensor c = Compute("c", {128}, [&](const VarHandle& i) { - return bitcast(a.load(i)); - }); - - BufHandle c_buf(c.buf()); - LoopNest l({c}); - StmtPtr s = l.root_stmt(); - ASSERT_TRUE(LoopNest::vectorize(to(to(s)->front()))); - ASSERT_TRUE(to(to(s)->front()) == nullptr); - - LLVMCodeGen cg(s, {a, c_buf}); - - std::vector a_vec(128); - std::vector c_vec(128); - for (const auto i : c10::irange(128)) { - a_vec[i] = raw_bitcast(1337.f); - } - std::vector args({a_vec.data(), c_vec.data()}); - ASSERT_EQ(cg.value(args), 0); - assertAllEqual(c_vec, 1337.f); -} - -TEST(LLVM, MemcpyTest) { - constexpr int N = 32; - BufHandle a("A", {N}, kInt); - BufHandle b("B", {N}, kInt); - std::vector a_buffer(N, 42); - std::vector b_buffer(N, 0); - - VarHandle i("i", kInt); - auto expr = For::make(i, 0, N, b.store({i}, a.load(i))); - - LLVMCodeGen cg(expr, {a, b}); - - std::vector args({a_buffer.data(), b_buffer.data()}); - ASSERT_EQ(cg.value(args), 0); - - ASSERT_EQ(a_buffer.size(), N); - ASSERT_EQ(b_buffer.size(), N); - assertAllEqual(a_buffer, 42); - assertAllEqual(b_buffer, 42); -} - -TEST(LLVM, BzeroTest) { - constexpr int N = 32; - BufHandle b("B", {N}, kInt); - std::vector b_buffer(N, 11); - - VarHandle i("i", kInt); - auto expr = For::make(i, 0, N, b.store({i}, 0)); - - LLVMCodeGen cg(expr, {b}); - - std::vector args({b_buffer.data()}); - ASSERT_EQ(cg.value(args), 0); - - ASSERT_EQ(b_buffer.size(), N); - assertAllEqual(b_buffer, 0); -} - -TEST(LLVM, ElemwiseAdd) { - constexpr int N = 1024; - BufHandle a("A", {N}, kInt); - BufHandle b("B", {N}, kInt); - BufHandle c("C", {N}, kInt); - std::vector a_buffer(N, 41); - std::vector b_buffer(N, 1); - std::vector c_buffer(N, 1); - - VarHandle i("i", kInt); - auto expr = For::make(i, 0, N, c.store({i}, Add::make(a.load(i), b.load(i)))); - - LLVMCodeGen cg(expr, {a, b, c}); - - std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); - ASSERT_EQ(cg.value(args), 0); - - ASSERT_EQ(a_buffer.size(), N); - ASSERT_EQ(b_buffer.size(), N); - ASSERT_EQ(c_buffer.size(), N); - assertAllEqual(a_buffer, 41); - assertAllEqual(b_buffer, 1); - assertAllEqual(c_buffer, 42); -} - -TEST(LLVM, ElemwiseAddFloat) { - constexpr int N = 1024; - BufHandle a("A", {N}, 
kFloat); - BufHandle b("B", {N}, kFloat); - BufHandle c("C", {N}, kFloat); - std::vector a_buffer(N, 41); - std::vector b_buffer(N, 1); - std::vector c_buffer(N, 1); - - VarHandle i("i", kInt); - auto expr = For::make(i, 0, N, c.store({i}, a.load(i) + b.load(i))); - - LLVMCodeGen cg(expr, {a, b, c}); - - std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); - ASSERT_EQ(cg.value(args), 0); - - ASSERT_EQ(a_buffer.size(), N); - ASSERT_EQ(b_buffer.size(), N); - ASSERT_EQ(c_buffer.size(), N); - assertAllEqual(a_buffer, 41.0f); - assertAllEqual(b_buffer, 1.0f); - assertAllEqual(c_buffer, 42.0f); -} - -TEST(LLVM, ElemwiseLog10Float) { - constexpr int N = 1024; - BufHandle a("A", {N}, kFloat); - BufHandle b("B", {N}, kFloat); - std::vector a_buffer(N, 10.0f); - std::vector b_buffer(N, 2.0f); - - VarHandle i("i", kInt); - auto expr = For::make( - i, - 0, - N / 4, - b.store( - {Ramp::make(i * 4, 1, 4)}, log10(a.load({Ramp::make(i * 4, 1, 4)})))); - - LLVMCodeGen cg(expr, {a, b}); - - std::vector args({a_buffer.data(), b_buffer.data()}); - ASSERT_EQ(cg.value(args), 0); - - ASSERT_EQ(a_buffer.size(), N); - ASSERT_EQ(b_buffer.size(), N); - assertAllEqual(a_buffer, 10.0f); - assertAllEqual(b_buffer, 1.0f); -} - -TEST(LLVM, ElemwiseLog1pFloat) { - constexpr int N = 1024; - BufHandle a("A", {N}, kFloat); - BufHandle b("B", {N}, kFloat); - std::vector a_buffer(N, expf(3.0f) - 1); - std::vector b_buffer(N, 42.0f); - - VarHandle i("i", kInt); - auto expr = For::make( - i, - 0, - N / 4, - b.store( - {Ramp::make(i * 4, 1, 4)}, log1p(a.load({Ramp::make(i * 4, 1, 4)})))); - - LLVMCodeGen cg(expr, {a, b}); - - std::vector args({a_buffer.data(), b_buffer.data()}); - ASSERT_EQ(cg.value(args), 0); - - ASSERT_EQ(a_buffer.size(), N); - ASSERT_EQ(b_buffer.size(), N); - assertAllEqual(a_buffer, expf(3.0f) - 1); - ExpectAllNear(b_buffer, 3.0f, 1e-5f); -} - -TEST(LLVM, ElemwiseMaxInt) { - constexpr int N = 1024; - BufHandle a("A", {N}, kInt); - BufHandle b("B", {N}, kInt); - BufHandle c("C", {N}, kInt); - std::vector a_buffer(N, 41); - std::vector b_buffer(N, 1); - std::vector c_buffer(N, 1); - - VarHandle i("i", kInt); - auto expr = - For::make(i, 0, N, c.store({i}, Max::make(a.load(i), b.load(i), false))); - - LLVMCodeGen cg(expr, {a, b, c}); - - std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); - ASSERT_EQ(cg.value(args), 0); - - ASSERT_EQ(a_buffer.size(), N); - ASSERT_EQ(b_buffer.size(), N); - ASSERT_EQ(c_buffer.size(), N); - assertAllEqual(a_buffer, 41); - assertAllEqual(b_buffer, 1); - assertAllEqual(c_buffer, 41); -} - -TEST(LLVM, ElemwiseMinInt) { - constexpr int N = 1024; - BufHandle a("A", {N}, kInt); - BufHandle b("B", {N}, kInt); - BufHandle c("C", {N}, kInt); - std::vector a_buffer(N, 41); - std::vector b_buffer(N, 1); - std::vector c_buffer(N, 1); - - VarHandle i("i", kInt); - auto expr = - For::make(i, 0, N, c.store({i}, Min::make(a.load(i), b.load(i), false))); - - LLVMCodeGen cg(expr, {a, b, c}); - - std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); - ASSERT_EQ(cg.value(args), 0); - - ASSERT_EQ(a_buffer.size(), N); - ASSERT_EQ(b_buffer.size(), N); - ASSERT_EQ(c_buffer.size(), N); - assertAllEqual(a_buffer, 41); - assertAllEqual(b_buffer, 1); - assertAllEqual(c_buffer, 1); -} - -TEST(LLVM, ElemwiseMaxFloat) { - constexpr int N = 1024; - BufHandle a("A", {N}, kFloat); - BufHandle b("B", {N}, kFloat); - BufHandle c("C", {N}, kFloat); - std::vector a_buffer(N, 41); - std::vector b_buffer(N, 1); - std::vector c_buffer(N, 1); - - VarHandle i("i", kInt); 
- auto expr = - For::make(i, 0, N, c.store({i}, Max::make(a.load(i), b.load(i), false))); - - LLVMCodeGen cg(expr, {a, b, c}); - - std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); - ASSERT_EQ(cg.value(args), 0); - - ASSERT_EQ(a_buffer.size(), N); - ASSERT_EQ(b_buffer.size(), N); - ASSERT_EQ(c_buffer.size(), N); - assertAllEqual(a_buffer, 41.0f); - assertAllEqual(b_buffer, 1.0f); - assertAllEqual(c_buffer, 41.0f); -} - -TEST(LLVM, ElemwiseMaxNaNFloat) { - constexpr int N = 1024; - BufHandle a("A", {N}, kFloat); - BufHandle b("B", {N}, kFloat); - BufHandle c("C", {N}, kFloat); - std::vector a_buffer(N, NAN); - std::vector b_buffer(N, 1); - std::vector c_buffer(N, 1); - - VarHandle i("i", kInt); - auto expr = - For::make(i, 0, N, c.store({i}, Max::make(a.load(i), b.load(i), false))); - - LLVMCodeGen cg(expr, {a, b, c}); - - std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); - ASSERT_EQ(cg.value(args), 0); - - ASSERT_EQ(a_buffer.size(), N); - ASSERT_EQ(b_buffer.size(), N); - ASSERT_EQ(c_buffer.size(), N); - assertAllEqual(b_buffer, 1.0f); - for (auto const& elt : c_buffer) { - ASSERT_TRUE(std::isnan(elt)); - } -} - -TEST(LLVM, ElemwiseMinFloat) { - constexpr int N = 1024; - BufHandle a("A", {N}, kFloat); - BufHandle b("B", {N}, kFloat); - BufHandle c("C", {N}, kFloat); - std::vector a_buffer(N, 41); - std::vector b_buffer(N, 1); - std::vector c_buffer(N, 1); - - VarHandle i("i", kInt); - auto expr = - For::make(i, 0, N, c.store({i}, Min::make(a.load(i), b.load(i), false))); - - LLVMCodeGen cg(expr, {a, b, c}); - - std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); - ASSERT_EQ(cg.value(args), 0); - - ASSERT_EQ(a_buffer.size(), N); - ASSERT_EQ(b_buffer.size(), N); - ASSERT_EQ(c_buffer.size(), N); - assertAllEqual(a_buffer, 41.0f); - assertAllEqual(b_buffer, 1.0f); - assertAllEqual(c_buffer, 1.0f); -} - -TEST(LLVM, ElemwiseMinNaNFloat) { - constexpr int N = 1024; - BufHandle a("A", {N}, kFloat); - BufHandle b("B", {N}, kFloat); - BufHandle c("C", {N}, kFloat); - std::vector a_buffer(N, NAN); - std::vector b_buffer(N, 1); - std::vector c_buffer(N, 1); - - VarHandle i("i", kInt); - auto expr = - For::make(i, 0, N, c.store({i}, Min::make(a.load(i), b.load(i), false))); - - LLVMCodeGen cg(expr, {a, b, c}); - - std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); - ASSERT_EQ(cg.value(args), 0); - - ASSERT_EQ(a_buffer.size(), N); - ASSERT_EQ(b_buffer.size(), N); - ASSERT_EQ(c_buffer.size(), N); - assertAllEqual(b_buffer, 1.0f); - for (auto const& elt : c_buffer) { - ASSERT_TRUE(std::isnan(elt)); - } -} - -TEST(LLVM, ElemwiseMod) { - constexpr int N = 1024; - BufHandle a("A", {N}, kInt); - BufHandle b("B", {N}, kInt); - BufHandle c("C", {N}, kInt); - std::vector a_buffer(N, 41); - std::vector b_buffer(N, 23); - std::vector c_buffer(N, 18); - - VarHandle i("i", kInt); - auto expr = For::make(i, 0, N, c.store({i}, Mod::make(a.load(i), b.load(i)))); - - LLVMCodeGen cg(expr, {a, b, c}); - - std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); - ASSERT_EQ(cg.value(args), 0); - - ASSERT_EQ(a_buffer.size(), N); - ASSERT_EQ(b_buffer.size(), N); - ASSERT_EQ(c_buffer.size(), N); - assertAllEqual(a_buffer, 41); - assertAllEqual(b_buffer, 23); - assertAllEqual(c_buffer, 18); -} - -TEST(LLVM, CompareSelectIntEQ) { - constexpr int N = 1024; - BufHandle a("A", {N}, kInt); - BufHandle b("B", {N}, kInt); - BufHandle c("C", {N}, kInt); - std::vector a_buffer(N, 1); - std::vector b_buffer(N, 1); - std::vector c_buffer(N, 0); 
- std::vector c_ref(N, 1); - - for (int i = 0; i < N / 2; i++) { - b_buffer[i] = 0; - c_ref[i] = 0; - } - - VarHandle i("i", kInt); - auto expr = For::make( - i, - 0, - N, - c.store( - {i}, - CompareSelect::make( - a.load(i), b.load(i), CompareSelectOperation::kEQ))); - - LLVMCodeGen cg(expr, {a, b, c}); - - std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); - ASSERT_EQ(cg.value(args), 0); - - ASSERT_EQ(a_buffer.size(), N); - ASSERT_EQ(b_buffer.size(), N); - ASSERT_EQ(c_buffer.size(), N); - - assertAllEqual(a_buffer, 1); - for (const auto i : c10::irange(N)) { - ASSERT_EQ(c_ref[i], c_buffer[i]); - } -} - -TEST(LLVM, CompareSelectFloatEQ) { - constexpr int N = 1024; - BufHandle a("A", {N}, kFloat); - BufHandle b("B", {N}, kFloat); - BufHandle c("C", {N}, kInt); - std::vector a_buffer(N, 1.0f); - std::vector b_buffer(N, 1.0f); - std::vector c_buffer(N, 0); - - VarHandle i("i", kInt); - auto expr = For::make( - i, - 0, - N, - c.store( - {i}, - CompareSelect::make( - a.load(i), b.load(i), CompareSelectOperation::kEQ))); - - LLVMCodeGen cg(expr, {a, b, c}); - - std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); - ASSERT_EQ(cg.value(args), 0); - - ASSERT_EQ(a_buffer.size(), N); - ASSERT_EQ(b_buffer.size(), N); - ASSERT_EQ(c_buffer.size(), N); - - assertAllEqual(a_buffer, 1.0f); - assertAllEqual(b_buffer, 1.0f); - assertAllEqual(c_buffer, 1); -} - -TEST(LLVM, CompareSelectByteGT) { - constexpr int N = 1024; - BufHandle a("A", {N}, kByte); - BufHandle b("B", {N}, kByte); - BufHandle c("C", {N}, kInt); - std::vector a_buffer(N, 0); - std::vector b_buffer(N, 0); - std::vector c_buffer(N, 0); - std::vector c_ref(N, 0); - - for (int i = 0; i < N / 2; i++) { - a_buffer[i] = 128; - c_ref[i] = 1; - } - - VarHandle i("i", kInt); - auto expr = For::make( - i, - 0, - N, - c.store( - {i}, - CompareSelect::make( - a.load(i), b.load(i), CompareSelectOperation::kGT))); - - LLVMCodeGen cg(expr, {a, b, c}); - - std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); - ASSERT_EQ(cg.value(args), 0); - - ASSERT_EQ(a_buffer.size(), N); - ASSERT_EQ(b_buffer.size(), N); - ASSERT_EQ(c_buffer.size(), N); - - assertAllEqual(b_buffer, uint8_t(0)); - for (const auto i : c10::irange(N)) { - ASSERT_EQ(c_ref[i], c_buffer[i]); - } -} - -TEST(LLVM, CompareSelectByteGE) { - constexpr int N = 1024; - BufHandle a("A", {N}, kByte); - BufHandle b("B", {N}, kByte); - BufHandle c("C", {N}, kInt); - std::vector a_buffer(N, 0); - std::vector b_buffer(N, 0); - std::vector c_buffer(N, 0); - std::vector c_ref(N, 1); - - VarHandle i("i", kInt); - auto expr = For::make( - i, - 0, - N, - c.store( - {i}, - CompareSelect::make( - a.load(i), b.load(i), CompareSelectOperation::kGE))); - - LLVMCodeGen cg(expr, {a, b, c}); - - std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); - ASSERT_EQ(cg.value(args), 0); - - ASSERT_EQ(a_buffer.size(), N); - ASSERT_EQ(b_buffer.size(), N); - ASSERT_EQ(c_buffer.size(), N); - - assertAllEqual(b_buffer, uint8_t(0)); - for (const auto i : c10::irange(N)) { - ASSERT_EQ(c_ref[i], c_buffer[i]); - } -} - -TEST(LLVM, CompareSelectByteLT) { - constexpr int N = 1024; - BufHandle a("A", {N}, kByte); - BufHandle b("B", {N}, kByte); - BufHandle c("C", {N}, kInt); - std::vector a_buffer(N, 0); - std::vector b_buffer(N, 128); - std::vector c_buffer(N, 0); - std::vector c_ref(N, 1); - - for (int i = 0; i < N / 2; i++) { - a_buffer[i] = 128; - c_ref[i] = 0; - } - - VarHandle i("i", kInt); - auto expr = For::make( - i, - 0, - N, - c.store( - {i}, - 
CompareSelect::make( - a.load(i), b.load(i), CompareSelectOperation::kLT))); - - LLVMCodeGen cg(expr, {a, b, c}); - - std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); - ASSERT_EQ(cg.value(args), 0); - - ASSERT_EQ(a_buffer.size(), N); - ASSERT_EQ(b_buffer.size(), N); - ASSERT_EQ(c_buffer.size(), N); - - assertAllEqual(b_buffer, uint8_t(128)); - for (const auto i : c10::irange(N)) { - ASSERT_EQ(c_ref[i], c_buffer[i]); - } -} - -TEST(LLVM, CompareSelectByteLE) { - constexpr int N = 1024; - BufHandle a("A", {N}, kByte); - BufHandle b("B", {N}, kByte); - BufHandle c("C", {N}, kInt); - std::vector a_buffer(N, 0); - std::vector b_buffer(N, 128); - std::vector c_buffer(N, 0); - std::vector c_ref(N, 1); - - VarHandle i("i", kInt); - auto expr = For::make( - i, - 0, - N, - c.store( - {i}, - CompareSelect::make( - a.load(i), b.load(i), CompareSelectOperation::kLE))); - - LLVMCodeGen cg(expr, {a, b, c}); - - std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); - ASSERT_EQ(cg.value(args), 0); - - ASSERT_EQ(a_buffer.size(), N); - ASSERT_EQ(b_buffer.size(), N); - ASSERT_EQ(c_buffer.size(), N); - - assertAllEqual(b_buffer, uint8_t(128)); - for (const auto i : c10::irange(N)) { - ASSERT_EQ(c_ref[i], c_buffer[i]); - } -} - -TEST(LLVM, StoreFloat) { - BufHandle result("result", {1}, kFloat); - std::vector result_buffer = {0.0f}; - auto expr = result.store({0}, FloatImm::make(3.14f)); - LLVMCodeGen cg(expr, {result}); - std::vector args({result_buffer.data()}); - ASSERT_EQ(cg.value(args), 0); - ASSERT_EQ(result_buffer[0], 3.14f); -} - -TEST(LLVM, SimpleMath01) { - const int N = 1024; - Tensor tensor = Compute( - "f", {N}, [](const VarHandle& i) { return cast(i * i + 1); }); - LoopNest l({tensor}); - StmtPtr stmt = l.root_stmt(); - BufHandle f_buf(tensor.buf()); - LLVMCodeGen cg(stmt, {f_buf}); - - PaddedBuffer f_v(N, "f_v"); - std::vector args({f_v.data()}); - int value = cg.value(args); - ASSERT_EQ(value, 0); - PaddedBuffer f_ref(N, "f_ref"); - for (const auto i : c10::irange(N)) { - f_ref(i) = i * i + 1; - } - ExpectAllNear(f_v, f_ref, 1e-5); -} - -TEST(LLVM, ComputeMul) { - const int N = 1024; - BufHandle a("a", {N}, kFloat); - BufHandle b("b", {N}, kFloat); - Tensor c = Compute( - "c", {N}, [&](const VarHandle& i) { return a.load(i) * b.load(i); }); - - BufHandle c_buf(c.buf()); - LoopNest l({c}); - StmtPtr s = l.root_stmt(); - - LLVMCodeGen cg(s, {a, b, c_buf}); - - std::vector a_vec(N, 21.0f); - std::vector b_vec(N, 2.0f); - std::vector c_vec(N, 0.0f); - std::vector args({a_vec.data(), b_vec.data(), c_vec.data()}); - ASSERT_EQ(cg.value(args), 0); - assertAllEqual(c_vec, 42.0f); -} - -TEST(LLVM, BroadcastAdd) { - const int M = 32; - const int N = 1024; - BufHandle a("a", {M, N}, kFloat); - BufHandle b("b", {N}, kFloat); - Tensor c = Compute("c", {M, N}, [&](const VarHandle& i, const VarHandle& j) { - return a.load(i, j) + b.load(j); - }); - - BufHandle c_buf(c.buf()); - LoopNest l({c}); - l.prepareForCodegen(); - StmtPtr s = l.root_stmt(); - - LLVMCodeGen cg(s, {a, b, c_buf}); - - std::vector av(M * N); - std::iota(av.begin(), av.end(), 0); - std::vector bv(N); - std::iota(bv.begin(), bv.end(), 0); - std::vector cv(M * N, 0); - std::vector args({av.data(), bv.data(), cv.data()}); - ASSERT_EQ(cg.value(args), 0); - - for (const auto i : c10::irange(M)) { - for (const auto j : c10::irange(N)) { - ASSERT_EQ(cv[i * N + j], av[i * N + j] + bv[j]); - } - } -} - -TEST(LLVM, BitwiseOps) { - auto a = IntImm::make(59); - auto b = IntImm::make(11); - auto c = 
IntImm::make(101); - auto d = IntImm::make(2); - - ExprHandle f = (((a ^ (b << 1)) & c) >> 2) | d; - LLVMExprEval cg(f); - - ASSERT_EQ(cg.value(), 11); -} - -TEST(LLVM, ArithmeticRightShift) { - auto a = CharImm::make(-4); - auto b = CharImm::make(1); - ExprHandle f = a >> b; - LLVMExprEval cg(f); - ASSERT_EQ(cg.value(), -2); -} - -TEST(LLVM, LogicalRightShift) { - auto a = ByteImm::make(0xfc); - auto b = ByteImm::make(1); - ExprHandle f = a >> b; - LLVMExprEval cg(f); - ASSERT_EQ(cg.value(), 0x7e); -} - -TEST(LLVM, DynamicShapeAdd) { - auto testWithSize = [](int32_t size) { - VarHandle n("n", kInt); - BufHandle a("a", {n}, kFloat); - BufHandle b("b", {n}, kFloat); - BufHandle c("c", {n}, kFloat); - VarHandle i("i", kInt); - StmtPtr s = For::make(i, 0, n, c.store({i}, a.load(i) + b.load(i))); - std::vector aData(size, 1.0f); - std::vector bData(size, 2.0f); - std::vector cData(size, 0.0f); - LLVMCodeGen cg(s, {a, b, c, n}); - std::vector args({aData.data(), bData.data(), cData.data(), &size}); - cg.value(args); - ExpectAllNear(cData, std::vector(size, 3.0f), 1e-7); - }; - testWithSize(1); - testWithSize(16); - testWithSize(37); -} - -TEST(LLVM, BindDynamicShapeAdd) { - auto testWithSize = [](int32_t size) { - VarHandle n("n", kInt); - BufHandle a("a", {n}, kFloat); - BufHandle b("b", {n}, kFloat); - BufHandle c("c", {n}, kFloat); - VarHandle i("i", kInt); - StmtPtr s = For::make(i, 0, n, c.store({i}, a.load(i) + b.load(i))); - std::vector aData(size, 1.0f); - std::vector bData(size, 2.0f); - std::vector cData(size, 0.0f); - LLVMCodeGen cg(s, {a, b, c, n}); - cg.call({aData, bData, cData, size}); - ExpectAllNear(cData, std::vector(size, 3.0f), 1e-7); - }; - testWithSize(1); - testWithSize(16); - testWithSize(37); -} - -TEST(LLVM, TensorDynamicShapeAdd) { - auto testWithSize = [](int32_t size) { - VarHandle n("n", kInt); - BufHandle a("a", {n}, kFloat); - BufHandle b("b", {n}, kFloat); - Tensor c = Compute( - "c", {n}, [&](const VarHandle& i) { return a.load(i) + b.load(i); }); - LoopNest l({c}); - StmtPtr s = l.root_stmt(); - LLVMCodeGen cg(s, {a, b, c, n}); - std::vector aData(size, 1.0f); - std::vector bData(size, 2.0f); - std::vector cData(size, 0.0f); - cg.call({aData, bData, cData, size}); - ExpectAllNear(cData, std::vector(size, 3.0f), 1e-7); - }; - testWithSize(1); - testWithSize(16); - testWithSize(37); -} - -TEST(LLVM, DynamicShape2D) { - auto testWithSize = [](int32_t M, int32_t N) { - VarHandle m("m", kInt); - VarHandle n("n", kInt); - BufHandle a("a", {m, n}, kFloat); - BufHandle b("b", {m, n}, kFloat); - Tensor c = - Compute("c", {m, n}, [&](const VarHandle& i, const VarHandle& j) { - return a.load(i, j) + b.load(i, j); - }); - LoopNest l({c}); - l.prepareForCodegen(); - StmtPtr s = l.root_stmt(); - LLVMCodeGen cg(s, {a, b, c, m, n}); - std::vector aData(M * N, 1.0f); - std::vector bData(M * N, 2.0f); - std::vector cData(M * N, 0.0f); - cg.call({aData, bData, cData, M, N}); - ExpectAllNear(cData, std::vector(M * N, 3.0f), 1e-7); - }; - testWithSize(1, 8); - testWithSize(16, 32); - testWithSize(37, 11); -} - -TEST(LLVM, EmptyStmt) { - StmtPtr s = alloc(std::vector({})); - - LLVMCodeGen cg(s, {}); - cg.call({}); - // Just don't crash. 
-} - -TEST(LLVM, EliminatedStmt) { - BufHandle a("a", {1}, kFloat); - - Tensor c = Compute("c", {0}, [&](const VarHandle& m) { return m; }); - - LoopNest l({c}); - l.prepareForCodegen(); - StmtPtr s = l.root_stmt(); - s = IRSimplifier::simplify(s); - LLVMCodeGen cg(s, {a, c}); - std::vector aData(1, 1.0f); - std::vector cData(0, 0.0f); - cg.call({aData, cData}); -} - -TEST(LLVM, SimpleReduction) { - int M = 128; - int N = 64; - - BufHandle a("a", {1, M, N}, kFloat); - - Tensor b = Reduce("sum", {1}, Sum(), a, {M, N}); - LoopNest loop({b}); - - loop.prepareForCodegen(); - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - - LLVMCodeGen cg(s, {a, b}); - - PaddedBuffer a_v(1, M, N, "a_v"); - PaddedBuffer b_v(1, "b_v"); - PaddedBuffer b_ref(1, "b_ref"); - - b_ref(0) = 0; - for (const auto i : c10::irange(M)) { - for (const auto j : c10::irange(N)) { - int v = i + j; - a_v(0, i, j) = v; - b_ref(0) += v; - } - } - - cg.call({a_v, b_v}); - - ExpectAllNear(b_v, b_ref, 1e-5); -} - -TEST(LLVM, RFactorReduction) { - int M = 128; - int N = 64; - - BufHandle a("a", {1, M, N}, kFloat); - - Tensor b = Reduce("sum", {1}, Sum(), a, {M, N}); - LoopNest loop({b}); - - std::vector loops = loop.getLoopStmtsFor(b); - ForPtr loop_m = loops.at(1); - ForPtr loop_n = loops.at(2); - loop.reorderAxis(loop_m, loop_n); - - loops = loop.getLoopStmtsFor(b); - loop_m = loops.at(2); - loop_n = loops.at(1); - auto b_body = loop.getAllWritesToBuf(b.buf())[1]; - ASSERT_TRUE(loop.rfactor(b_body, loop_n)); - - loop.prepareForCodegen(); - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - - LLVMCodeGen cg(s, {a, b}); - - PaddedBuffer a_v(1, M, N, "a_v"); - PaddedBuffer b_v(1, "b_v"); - PaddedBuffer b_ref(1, "b_ref"); - - b_ref(0) = 0; - for (const auto i : c10::irange(M)) { - for (const auto j : c10::irange(N)) { - int v = i + j; - a_v(0, i, j) = v; - b_ref(0) += v; - } - } - - cg.call({a_v, b_v}); - - ExpectAllNear(b_v, b_ref, 1e-5); -} - -TEST(LLVM, RFactorVectorizedReduction) { - int M = 128; - int N = 64; - - BufHandle a("a", {1, M, N}, kFloat); - - Tensor b = Reduce("sum", {1}, Sum(), a, {M, N}); - LoopNest loopnest({b}); - std::vector loops = loopnest.getLoopStmtsFor(b); - // Reorder n and m loops - loopnest.reorderAxis(loops.at(1), loops.at(2)); - auto b_body = loopnest.getAllWritesToBuf(b.buf()).at(1); - auto all_loops = loopnest.getAllLoopNestsWritingToBuf(b.buf()); - ASSERT_TRUE(all_loops.size() == 2 && all_loops[1].size() == 3); - ASSERT_TRUE(loopnest.rfactor(b_body, all_loops[1][1])); - auto distributed_loops = loopnest.distributeLoop(all_loops[1][1]); - - // Vectorize initializer of rfac_buf - ASSERT_TRUE(LoopNest::vectorize(distributed_loops[0])); - // Vectorize producer of rfac_buf - ASSERT_TRUE(LoopNest::vectorize(distributed_loops[1])); - loopnest.simplify(); - - loopnest.prepareForCodegen(); - - StmtPtr s = IRSimplifier::simplify(loopnest.root_stmt()); - LLVMCodeGen cg(s, {a, b}); - - PaddedBuffer a_v(1, M, N, "a_v"); - PaddedBuffer b_v(1, "b_v"); - PaddedBuffer b_ref(1, "b_ref"); - - b_ref(0) = 0; - for (const auto i : c10::irange(M)) { - for (const auto j : c10::irange(N)) { - int v = i + j; - a_v(0, i, j) = v; - b_ref(0) += v; - } - } - - cg.call({a_v, b_v}); - - ExpectAllNear(b_v, b_ref, 1e-5); -} - -template -static void testSimpleParallel() { - // Compute a simple operation, and try all loop-axis combination to be - // parallel or sequential. 
- const int M = 4; - const int N = 6; - Tensor f = Compute("f", {M, N}, [](const VarHandle& m, const VarHandle& n) { - return cast(m + n); - }); - LoopNest loop_nest({f}); - auto const& loops = loop_nest.getLoopStmtsFor(f); - ForPtr m = loops[0]; - ForPtr n = loops[1]; - if (outer) { - m->set_parallel(); - } - if (inner) { - n->set_parallel(); - } - loop_nest.prepareForCodegen(); - StmtPtr stmt = loop_nest.root_stmt(); - LLVMCodeGen cg(stmt, {f}); - - PaddedBuffer f_v(M, N, "f_v"); - std::vector args({f_v.data()}); - int value = cg.value(args); - ASSERT_EQ(value, 0); - PaddedBuffer f_ref(M, N, "f_ref"); - for (const auto m : c10::irange(M)) { - for (const auto n : c10::irange(N)) { - f_ref(m, n) = m + n; - } - } - ExpectAllNear(f_v, f_ref, 1e-5); -} - -TEST(LLVM, SimpleParallelSS) { - testSimpleParallel(); -} -TEST(LLVM, SimpleParallelSP) { - testSimpleParallel(); -} -TEST(LLVM, SimpleParallelPS) { - testSimpleParallel(); -} -TEST(LLVM, SimpleParallelPP) { - testSimpleParallel(); -} - -TEST(LLVM, CompositeParallel) { - int loop_count = 6; - int test_count = 1 << loop_count; - // Compute a composite operation, and try all loop-axis combination to be - // parallel or sequential. - for (const auto test_cfg : c10::irange(test_count)) { - int M = 5; - int N = 7; - Tensor t1 = Compute("t1", {M}, [](const VarHandle& m) { return m + 1.f; }); - Tensor t2 = Compute("t2", {N}, [](const VarHandle& n) { return n + 2.f; }); - Tensor t3 = - Compute("t3", {M, N}, [=](const VarHandle& m, const VarHandle& n) { - return t1.load(m) * t2.load(n); - }); - Tensor t4 = - Compute("t4", {M, N}, [=](const VarHandle& m, const VarHandle& n) { - return t3.load(m, n) + m + n; - }); - LoopNest loop_nest({t4}, {t1, t2, t3, t4}); - std::vector loop_list; - { - auto const& loops = loop_nest.getLoopStmtsFor(t1); - loop_list.push_back(loops[0]); - } - { - auto const& loops = loop_nest.getLoopStmtsFor(t2); - loop_list.push_back(loops[0]); - } - { - auto const& loops = loop_nest.getLoopStmtsFor(t3); - loop_list.push_back(loops[0]); - loop_list.push_back(loops[1]); - } - { - auto const& loops = loop_nest.getLoopStmtsFor(t4); - loop_list.push_back(loops[0]); - loop_list.push_back(loops[1]); - } - ASSERT_EQ(loop_list.size(), loop_count); - for (const auto i : c10::irange(loop_count)) { - if (test_cfg & (1 << i)) { - loop_list[i]->set_parallel(); - } - } - loop_nest.prepareForCodegen(); - StmtPtr stmt = loop_nest.root_stmt(); - LLVMCodeGen cg(stmt, {t4}); - - PaddedBuffer t4_v(M, N, "t4_v"); - std::vector args({t4_v.data()}); - int value = cg.value(args); - ASSERT_EQ(value, 0); - PaddedBuffer t4_ref(M, N, "t4_ref"); - for (const auto m : c10::irange(M)) { - for (const auto n : c10::irange(N)) { - t4_ref(m, n) = (m + 1) * (n + 2) + m + n; - } - } - ExpectAllNear(t4_v, t4_ref, 1e-5); - } -} - -TEST(LLVM, VectorizedGEMM) { - int M = 32; - int N = 32; - int K = 48; - - BufHandle AP("A", {M, K}, kFloat); - BufHandle BP("B", {K, N}, kFloat); - Tensor CT = Reduce( - "gemm", - {M, N}, - Sum(), - [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) { - return AP.load(m, k) * BP.load(k, n); - }, - {K}); - LoopNest loop({CT}); - - { - auto const& loops = loop.getLoopStmtsFor(CT); - ForPtr m = loops[0]; - loop.splitWithMask(m, 16); - } - { - auto const& loops = loop.getLoopStmtsFor(CT); - ForPtr n = loops[2]; - loop.splitWithMask(n, 16); - } - // mo, mi, no, ni, k -> - // mo, no, mi, ni, k - { - auto const& loops = loop.getLoopStmtsFor(CT); - ForPtr mi = loops[1]; - ForPtr no = loops[2]; - loop.reorderAxis(mi, no); - } - // mo, 
no, mi, ni, k -> - // mo, no, mi, k, ni - { - auto const& loops = loop.getLoopStmtsFor(CT); - ForPtr ni = loops[3]; - ForPtr k = loops[4]; - loop.reorderAxis(ni, k); - } - // mo, no, mi, k, ni -> - // mo, no, k, mi, ni - { - auto const& loops = loop.getLoopStmtsFor(CT); - ForPtr mi = loops[2]; - ForPtr k = loops[3]; - loop.reorderAxis(mi, k); - } - { - auto loops = NodeFinder::find(loop.root_stmt()); - ASSERT_TRUE(LoopNest::vectorize(loops[3])); - ASSERT_TRUE(LoopNest::vectorize(loops.back())); - } - - loop.prepareForCodegen(); - - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - LLVMCodeGen cg(s, {AP, BP, CT}); - - PaddedBuffer a_v(M, K, "a_v"); - PaddedBuffer b_v(K, N, "b_v"); - PaddedBuffer c_v(M, N, "c_v"); - PaddedBuffer c_ref(M, N, "c_ref"); - - for (const auto m : c10::irange(M)) { - for (const auto n : c10::irange(N)) { - c_ref(m, n) = 0.f; - for (const auto k : c10::irange(K)) { - c_ref(m, n) += a_v(m, k) * b_v(k, n); - } - } - } - - cg.call({a_v, b_v, c_v}); - - ExpectAllNear(c_v, c_ref, 1e-5); -} - -TEST(LLVM, CallRaw) { - const int M = 32; - VarHandle N("N", kInt); - BufHandle a("a", {M, N}, kFloat); - BufHandle b("b", {N}, kFloat); - Tensor c = Compute("c", {M, N}, [&](const VarHandle& i, const VarHandle& j) { - return a.load(i, j) + b.load(j); - }); - - LoopNest l({c}); - l.prepareForCodegen(); - StmtPtr s = l.root_stmt(); - - int32_t N_value = 1024; - std::vector av(M * N_value); - std::iota(av.begin(), av.end(), 0); - std::vector bv(N_value); - std::iota(bv.begin(), bv.end(), 0); - std::vector cv(M * N_value, 0); - std::vector args({av.data(), bv.data(), cv.data(), &N_value}); - - LLVMCodeGen cg(s, {a, b, BufHandle(c.buf()), N}); - cg.call_raw(args); - - for (const auto i : c10::irange(M)) { - for (const auto j : c10::irange(N_value)) { - ASSERT_EQ(cv[i * N_value + j], av[i * N_value + j] + bv[j]); - } - } - - SimpleIREvaluator eval(s, {a, b, BufHandle(c.buf()), N}); - eval.call_raw(args); - - for (const auto i : c10::irange(M)) { - for (const auto j : c10::irange(N_value)) { - ASSERT_EQ(cv[i * N_value + j], av[i * N_value + j] + bv[j]); - } - } -} - -TEST(LLVM, CustomTarget) { - constexpr int M = 16; - BufHandle a("a", {M}, kFloat); - BufHandle b("b", {M}, kFloat); - BufHandle c("c", {M}, kFloat); - Tensor d = Compute("d", {M}, [&](const VarHandle& m) { - return a.load(m) * b.load(m) + c.load(m); - }); - LoopNest nest({d}); - nest.prepareForCodegen(); - auto cg = LLVMCodeGenBuilder(nest.root_stmt(), {a, b, c, d}) - .triple("i686-elf") - .cpu("i386") - .build(); - std::ostringstream ss; - ss << cg->getCodeText("asm"); - torch::jit::testing::FileCheck() - .check("fadds") - ->check("fmuls") - ->check_not("vfmadd") - ->run(ss.str()); -} - -TEST(LLVM, CodeGenKernelFuncName) { - BufHandle a("A", {1}, kInt); - BufHandle b("B", {1}, kInt); - std::vector a_buffer = {42}; - std::vector b_buffer = {-11}; - auto store = b.store({0}, a.load(0)); - - { - LLVMCodeGen cg(store, {a, b}); - // Check that the kernel function name used by LLVMCodeGen - // is not empty. - ASSERT_NE(cg.kernel_func_name(), ""); - } - - { - LLVMCodeGen cg(store, {a, b}, at::kCPU, "new_func"); - // Check that the kernel function name used by LLVMCodeGen - // is the one that was given above. 
ASSERT_EQ(cg.kernel_func_name(), "new_func"); - } -} - -} // namespace jit -} // namespace torch - -#endif // TORCH_ENABLE_LLVM diff --git a/test/cpp/tensorexpr/test_loopnest.cpp b/test/cpp/tensorexpr/test_loopnest.cpp deleted file mode 100644 index a8bda8814dba..000000000000 --- a/test/cpp/tensorexpr/test_loopnest.cpp +++ /dev/null @@ -1,6894 +0,0 @@ -#include - -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace torch { -namespace jit { - -using namespace torch::jit::tensorexpr; - -void checkIR(StmtPtr s, const std::string& pattern) { - std::ostringstream oss; - oss << *s; - torch::jit::testing::FileCheck().run(pattern, oss.str()); -} - -void checkExprIR(ExprPtr e, const std::string& pattern) { - std::string prefixed_pattern = "# CHECK: " + pattern + "\n"; - std::ostringstream oss; - oss << *e << "\n"; - torch::jit::testing::FileCheck().run(prefixed_pattern, oss.str()); -} - -void checkExprIR(const ExprHandle& e, const std::string& pattern) { - checkExprIR(e.node(), pattern); -} - -TEST(LoopNest, ExprSimple01) { - Tensor tensor = - Compute("f", {16, 5}, [](const VarHandle& x, const VarHandle& y) { - return ExprHandle(1.0f) + cast<float>(x) * x + cast<float>(y) * y; - }); - LoopNest l({tensor}); - std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - - LoopNest::splitWithTail(loops[0], 2); - LoopNest::splitWithTail(loops[0], 2); -} - -TEST(LoopNest, ExprLower01) { - Tensor tensor = - Compute("f", {16, 5}, [](const VarHandle& x, const VarHandle& y) { - return ExprHandle(1.0f) + cast<float>(x) * x + cast<float>(y) * y; - }); - LoopNest l({tensor}); - StmtPtr stmt = l.root_stmt(); - std::ostringstream oss; - oss << *stmt; - ASSERT_GT(oss.str().size(), 20); - ASSERT_LT(oss.str().size(), 200); -} - -TEST(LoopNest, ExprSimple02) { - auto func = [](const ExprHandle& x, const ExprHandle& y) { - return ExprHandle(1.0f) + cast<float>(x) * x + cast<float>(y) * y; - }; - Tensor tensor = Compute("f", {26, 5}, func); - LoopNest l({tensor}); - std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - - LoopNest::splitWithTail(loops[0], 4); - - StmtPtr stmt = l.root_stmt(); - std::ostringstream oss; - oss << *stmt; - ASSERT_GT(oss.str().size(), 200); - ASSERT_LT(oss.str().size(), 600); - - { - // Compare to a reference loop structure. 
- VarHandle x_outer("i_outer", kInt); - VarHandle x_inner("i_inner", kInt); - VarHandle y("i", kInt); - VarHandle x_tail("i_tail", kInt); - BufHandle f("f", {26, 5}, kFloat); - ExprHandle x_1 = x_outer * 4 + x_inner; - ExprHandle x_outer_end = (ExprHandle(26) - 0) / 4; - ForPtr stmt1 = For::make( - x_outer, - 0, - x_outer_end, - For::make( - x_inner, - 0, - 4, - For::make(y, 0, 5, Store::make(f, {x_1, y}, func(x_1, y))))); - ExprHandle x_2 = x_tail + x_outer_end * 4; - ForPtr stmt2 = For::make( - x_tail, - 0, - (ExprHandle(26) - 0) % 4, - For::make(y, 0, 5, Store::make(f, {x_2, y}, func(x_2, y)))); - StmtPtr stmt = Block::make({stmt1, stmt2}); - - std::ostringstream oss_ref; - oss_ref << *stmt; - ASSERT_EQ(oss.str(), oss_ref.str()); - } - - { - PaddedBuffer<float> f_v(26, 5, "f_v"); - PaddedBuffer<float> f_ref(26, 5, "f_res"); - - stmt = FlattenIndexes(stmt); - SimpleIREvaluator ir_eval(stmt, {tensor}); - ir_eval(f_v); - - for (int x = 0; x < 26; x++) { - for (int y = 0; y < 5; y++) { - f_ref(x, y) = 1 + x * x + y * y; - } - } - - ExpectAllNear(f_v, f_ref, 1e-5); - } -} - -BlockPtr getSimplifiedBody(const LoopNest& l) { - StmtPtr stmt = l.root_stmt(); - StmtPtr simplified = IRSimplifier::simplify(stmt); - return to<Block>(simplified); -} - -void assertForRange(ForPtr f, int expected_start, int expected_stop) { - ASSERT_NE(f, nullptr); - IntImmPtr start = to<IntImm>(f->start()); - ASSERT_NE(start, nullptr); - ASSERT_EQ(start->value(), expected_start); - IntImmPtr stop = to<IntImm>(f->stop()); - ASSERT_NE(stop, nullptr); - ASSERT_EQ(stop->value(), expected_stop); -} - -void assertForRanges( - BlockPtr body, - const std::vector<std::pair<int, int>>& start_stops) { - ASSERT_EQ(body->nstmts(), start_stops.size()); - - auto it = body->begin(); - for (size_t i = 0; i < start_stops.size(); i++, it++) { - ForPtr loop = to<For>(*it); - assertForRange(loop, start_stops[i].first, start_stops[i].second); - } -} - -TEST(LoopNest, ExprSliceHeadWithLoopOptions) { - auto func = [](const ExprHandle& x) { - return ExprHandle(1.0f) + cast<float>(x); - }; - Tensor tensor = Compute("f", {10}, func); - LoopNest l({tensor}); - ForPtr head; - ForPtr tail; - std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - loops[0]->set_gpu_block_index(LoopOptions::IDX_Y); - LoopNest::sliceHead(loops[0], 2, &head, &tail); - - BlockPtr body = getSimplifiedBody(l); - assertForRanges(body, {{0, 2}, {0, 8}}); - - ASSERT_TRUE(tail->loop_options().is_gpu_block_index()); - ASSERT_EQ(tail->loop_options().gpu_block_index(), LoopOptions::IDX_Y); - - ASSERT_TRUE(head->loop_options().isDefault()); -} - -TEST(LoopNest, ExprSliceTailWithLoopOptions) { - auto func = [](const ExprHandle& x) { - return ExprHandle(1.0f) + cast<float>(x); - }; - Tensor tensor = Compute("f", {10}, func); - LoopNest l({tensor}); - ForPtr head; - ForPtr tail; - std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - LoopNest::sliceTail(loops[0], 4, &head, &tail); - - ForPtr tail_head; - ForPtr tail_tail; - tail->set_gpu_block_index(LoopOptions::IDX_Y); - LoopNest::sliceTail(tail, 2, &tail_head, &tail_tail); - - BlockPtr body = getSimplifiedBody(l); - assertForRanges(body, {{0, 6}, {0, 2}, {8, 10}}); - - ASSERT_TRUE(tail_head->loop_options().is_gpu_block_index()); - ASSERT_EQ(tail_head->loop_options().gpu_block_index(), LoopOptions::IDX_Y); - - ASSERT_TRUE(head->loop_options().isDefault()); - ASSERT_TRUE(tail_tail->loop_options().isDefault()); -} - -TEST(LoopNest, ExprSliceHeadWhenFactorEqualsSize) { - // When factor equals the For loop's original size, keep using the original - // For loop. 
- auto func = [](const ExprHandle& x) { - return ExprHandle(1.0f) + cast<float>(x); - }; - Tensor tensor = Compute("f", {10}, func); - LoopNest l({tensor}); - ForPtr head; - ForPtr tail; - std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - LoopNest::sliceHead(loops[0], 10, &head, &tail); - - ASSERT_EQ(head, loops[0]); - ASSERT_EQ(tail, nullptr); - - BlockPtr body = getSimplifiedBody(l); - assertForRanges(body, {{0, 10}}); -} - -TEST(LoopNest, ExprSliceHeadWhenFactorLargerThanSize) { - auto func = [](const ExprHandle& x) { - return ExprHandle(1.0f) + cast<float>(x); - }; - Tensor tensor = Compute("f", {10}, func); - LoopNest l({tensor}); - ForPtr head; - ForPtr tail; - std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - LoopNest::sliceHead(loops[0], 100, &head, &tail); - - ASSERT_EQ(head, loops[0]); - ASSERT_EQ(tail, nullptr); - - BlockPtr body = getSimplifiedBody(l); - assertForRanges(body, {{0, 10}}); -} - -TEST(LoopNest, ExprSliceHead) { - auto func = [](const ExprHandle& x) { - return ExprHandle(1.0f) + cast<float>(x); - }; - Tensor tensor = Compute("f", {10}, func); - LoopNest l({tensor}); - ForPtr head; - ForPtr tail; - std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - LoopNest::sliceHead(loops[0], 4, &head, &tail); - - ASSERT_NE(head, nullptr); - ASSERT_NE(head, loops[0]); - ASSERT_NE(tail, nullptr); - ASSERT_EQ(tail, loops[0]); - - BlockPtr body = getSimplifiedBody(l); - assertForRanges(body, {{0, 4}, {4, 10}}); -} - -TEST(LoopNest, ExprSliceHeadWithNonZeroStart) { - auto func = [](const ExprHandle& x) { - return ExprHandle(1.0f) + cast<float>(x); - }; - Tensor tensor = Compute("f", {10}, func); - LoopNest l({tensor}); - std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - - ForPtr head; - ForPtr tail; - LoopNest::sliceTail(loops[0], 4, &head, &tail); - // head: [0, 6) - // tail: [6, 10) - - LoopNest::sliceHead(tail, 2); - // tail_head: [6, 8) - // tail_tail: [8, 10) - - BlockPtr body = getSimplifiedBody(l); - assertForRanges(body, {{0, 6}, {6, 8}, {8, 10}}); -} - -TEST(LoopNest, ExprSliceTailWhenFactorEqualsSize) { - // When factor equals the For loop's original size, keep using the original - // For loop. - auto func = [](const ExprHandle& x) { - return ExprHandle(1.0f) + cast<float>(x); - }; - Tensor tensor = Compute("f", {10}, func); - LoopNest l({tensor}); - ForPtr head; - ForPtr tail; - std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - LoopNest::sliceTail(loops[0], 10, &head, &tail); - - ASSERT_EQ(head, nullptr); - ASSERT_EQ(tail, loops[0]); - - BlockPtr body = getSimplifiedBody(l); - assertForRanges(body, {{0, 10}}); -} - -TEST(LoopNest, ExprSliceTailWhenFactorLargerThanSize) { - // When factor is larger than the For loop's original size, keep using the - // original For loop. 
- auto func = [](const ExprHandle& x) { - return ExprHandle(1.0f) + cast<float>(x); - }; - Tensor tensor = Compute("f", {10}, func); - LoopNest l({tensor}); - ForPtr head; - ForPtr tail; - std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - LoopNest::sliceTail(loops[0], 100, &head, &tail); - - ASSERT_EQ(head, nullptr); - ASSERT_EQ(tail, loops[0]); - - BlockPtr body = getSimplifiedBody(l); - assertForRanges(body, {{0, 10}}); -} - -TEST(LoopNest, ExprSliceTail) { - auto func = [](const ExprHandle& x) { - return ExprHandle(1.0f) + cast<float>(x); - }; - Tensor tensor = Compute("f", {10}, func); - LoopNest l({tensor}); - ForPtr head; - ForPtr tail; - std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - LoopNest::sliceTail(loops[0], 4, &head, &tail); - - ASSERT_NE(head, nullptr); - ASSERT_EQ(head, loops[0]); - ASSERT_NE(tail, nullptr); - ASSERT_NE(tail, loops[0]); - - BlockPtr body = getSimplifiedBody(l); - assertForRanges(body, {{0, 6}, {6, 10}}); -} - -TEST(LoopNest, ExprSplitAndSlice) { - // 0: splitWithTail - // 1: sliceTail on inner loop - // 2: sliceHead on outer loop - auto func = [](const ExprHandle& x) { - return ExprHandle(1.0f) + cast<float>(x); - }; - Tensor tensor = Compute("f", {100}, func); - LoopNest l({tensor}); - - ForPtr inner; - ForPtr tail; - std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - // outer: [0, 4) - // inner: [0, 21) - // tail: [84, 100) - LoopNest::splitWithTail(loops[0], 21, &inner, &tail); - LoopNest::sliceTail(inner, 2); - LoopNest::sliceHead(loops[0], 2); - - // for (int x_outer = 0; x_outer < 2; x_outer++) { - // for (int x_inner = 0; x_inner < 19; x_inner++) { - // f[21 * x_outer + x_inner] = 1.f + float(21 * x_outer + x_inner); - // } - // for (int x_inner = 19; x_inner < 21; x_inner++) { - // f[21 * x_outer + x_inner] = 1.f + float(21 * x_outer + x_inner); - // } - // } - // for (int x_outer = 2; x_outer < 4; x_outer++) { - // for (int x_inner = 0; x_inner < 19; x_inner++) { - // f[21 * x_outer + x_inner] = 1.f + float(21 * x_outer + x_inner); - // } - // for (int x_inner = 19; x_inner < 21; x_inner++) { - // f[21 * x_outer + x_inner] = 1.f + float(21 * x_outer + x_inner); - // } - // } - // for (int x_tail = 0; x_tail < 16; x_tail++) { - // f[x_tail + 84] = 1.f + float(x_tail + 84); - // } - BlockPtr body = getSimplifiedBody(l); - assertForRanges(body, {{0, 2}, {2, 4}, {0, 16}}); - - auto biter = body->begin(); - - ForPtr loop = to<For>(*biter++); - assertForRanges(loop->body(), {{0, 19}, {19, 21}}); - - loop = to<For>(*biter); - assertForRanges(loop->body(), {{0, 19}, {19, 21}}); -} - -TEST(LoopNest, ExprSliceAndNormalize) { - // 0: sliceHead - // 1: normalize tail - auto func = [](const ExprHandle& x) { - return ExprHandle(1.0f) + cast<float>(x); - }; - Tensor tensor = Compute("f", {10}, func); - LoopNest l({tensor}); - std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - - ForPtr head; - ForPtr tail; - LoopNest::sliceHead(loops[0], 2, &head, &tail); - // head: [0, 2) - // tail: [2, 10) - - LoopNest::normalize(tail); - // normalized_tail: [0, 8) - - BlockPtr body = getSimplifiedBody(l); - assertForRanges(body, {{0, 2}, {0, 8}}); -} - -template <typename T> -T evalExpr(const ExprHandle& expr, const VarHandle& var, T value) { - ExprEval<SimpleIREvaluator> eval(expr, {var}); - return eval.value<T>(value); -} - -TEST(LoopNest, ExprSliceWithVariableDimension) { - auto testWithDimension = - [](int dimension, - const std::vector<std::pair<int, int>>& expected_for_ranges) { - VarHandle dim("dim", kInt); - Tensor tensor = - Compute("f", {dim}, [](const ExprHandle& x) { 
return x; }); - LoopNest l({tensor}); - std::vector loops = - l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - - ForPtr head; - ForPtr tail; - LoopNest::sliceHead(loops[0], 2, &head, &tail); - - LoopNest::sliceTail(tail, 2); - - BlockPtr body = getSimplifiedBody(l); - ASSERT_EQ(expected_for_ranges.size(), 3); - auto it = body->begin(); - for (auto& start_stop : expected_for_ranges) { - ForPtr loop = to(*it++); - int start = evalExpr(ExprHandle(loop->start()), dim, dimension); - int stop = evalExpr(ExprHandle(loop->stop()), dim, dimension); - ASSERT_EQ(start, start_stop.first); - ASSERT_EQ(stop, start_stop.second); - } - }; - - testWithDimension(1, {{0, 1}, {1, 1}, {1, 1}}); - testWithDimension(2, {{0, 2}, {2, 2}, {2, 2}}); - testWithDimension(3, {{0, 2}, {2, 2}, {2, 3}}); - testWithDimension(4, {{0, 2}, {2, 2}, {2, 4}}); - testWithDimension(5, {{0, 2}, {2, 3}, {3, 5}}); - testWithDimension(10, {{0, 2}, {2, 8}, {8, 10}}); -} - -TEST(LoopNest, ExprSplitWithTail) { - auto func = [](const ExprHandle& x) { - return ExprHandle(1.0f) + cast(x); - }; - Tensor tensor = Compute("f", {199}, func); - LoopNest l({tensor}); - std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - LoopNest::splitWithTail(loops[0], 17); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - LoopNest::splitWithTail(loops[0], 7); - - StmtPtr stmt = l.root_stmt(); - StmtPtr simplified = IRSimplifier::simplify(stmt); - BlockPtr body = to(simplified); - ASSERT_EQ(body->nstmts(), 3); - auto biter = body->begin(); - - // Verify that the split loops are ordered correctly. - ForPtr loop = to(*biter++); - assertForRange(loop, 0, 7); - - loop = to(*biter++); - assertForRange(loop, 0, 4); - - loop = to(*biter); - assertForRange(loop, 0, 12); -} - -TEST(LoopNest, ExprSplitWithTailNone) { - auto func = [](const ExprHandle& x, const ExprHandle& y) { - return ExprHandle(1.0f) + cast(x) * x + cast(y) * y; - }; - Tensor tensor = Compute("f", {24, 5}, func); - LoopNest l({tensor}); - std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - LoopNest::splitWithTail(loops[0], 4); - - StmtPtr stmt = l.root_stmt(); - std::ostringstream oss; - oss << *stmt; - ASSERT_GT(oss.str().size(), 200); - ASSERT_LT(oss.str().size(), 600); - - { - // Compare to a reference loop structure structure. 
- VarHandle x_outer("i_outer", kInt); - VarHandle x_inner("i_inner", kInt); - VarHandle y("i", kInt); - VarHandle x_tail("i_tail", kInt); - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks,cppcoreguidelines-avoid-magic-numbers) - BufHandle f("f", {24, 5}, kFloat); - ExprHandle x_1 = x_outer * 4 + x_inner; - ExprHandle x_outer_end = (ExprHandle(24) - 0) / 4; - StmtPtr stmt = alloc(std::vector({For::make( - x_outer, - 0, - x_outer_end, - For::make( - x_inner, - 0, - 4, - For::make(y, 0, 5, Store::make(f, {x_1, y}, func(x_1, y)))))})); - - std::ostringstream oss_ref; - oss_ref << *stmt; - ASSERT_EQ(oss.str(), oss_ref.str()); - } - - { - PaddedBuffer f_v(24, 5, "f_v"); - PaddedBuffer f_ref(24, 5, "f_res"); - - SimpleIREvaluator ir_eval(stmt, {tensor}); - ir_eval(f_v); - - for (int x = 0; x < 24; x++) { - for (int y = 0; y < 5; y++) { - f_ref(x, y) = 1 + x * x + y * y; - } - } - - ExpectAllNear(f_v, f_ref, 1e-5); - } -} - -TEST(LoopNest, ExprSplitWithMask01) { - const int M = 26; - const int N = 5; - BufHandle a_buf("a", {M, N}, kFloat); - BufHandle b_buf("b", {M, N}, kFloat); - Tensor tensor = - Compute("f", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { - return a_buf.load(m, n) + b_buf.load(m, n) + 1.0f; - }); - - LoopNest l({tensor}); - std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - LoopNest::splitWithMask(loops[1], 4); - - StmtPtr stmt = l.root_stmt(); - - PaddedBuffer a_v(M, N, "a"); - PaddedBuffer b_v(M, N, "b"); - PaddedBuffer c_v(M, N, "c"); - PaddedBuffer c_ref(M, N, "c_ref"); - for (int m = 0; m < M; m++) { - for (int n = 0; n < N; n++) { - a_v(m, n) = 2 * m; - b_v(m, n) = 3 * n; - c_ref(m, n) = a_v(m, n) + b_v(m, n) + 1.0f; - } - } - - SimpleIREvaluator(stmt, {a_buf, b_buf, tensor})(a_v, b_v, c_v); - - ExpectAllNear(c_v, c_ref, 1e-5); -} - -// Tests the case where we split a loop cleanly multiple times, we should not -// insert any masks. -TEST(LoopNest, ExprSplitWithMaskRepeatedNoMask) { - const int M = 64; - BufHandle a_buf("a", {M}, kFloat); - BufHandle b_buf("b", {M}, kFloat); - Tensor tensor = Compute("f", {M}, [&](const ExprHandle& m) { - return a_buf.load(m) + b_buf.load(m) + 1.0f; - }); - - LoopNest l({tensor}); - std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - LoopNest::splitWithMask(loops[0], 4); - LoopNest::splitWithMask(loops[0], 4); - - StmtPtr stmt1 = IRSimplifier::simplify(l.root_stmt()); - - // Two splits mean 3 loops, but should need no masks in this case. 
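-  // (splitWithMask only guards the body when the factor does not divide
-  // the extent evenly. Both splits here are exact: 64 / 4 = 16, then
-  // 16 / 4 = 4, so the expected shape is, roughly:
-  //
-  //   for (int a = 0; a < 4; a++)
-  //     for (int b = 0; b < 4; b++)
-  //       for (int c = 0; c < 4; c++)
-  //         f[16 * a + 4 * b + c] = ...
-  //
-  // with no conditionals anywhere, which is what the CHECK-NOTs verify.
-  // The variable names above are illustrative, not the printer's exact
-  // names.)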
- checkIR(stmt1, R"IR( -# CHECK: for ( -# CHECK-NOT: if ( -# CHECK: for ( -# CHECK-NOT: if ( -# CHECK: for ( -# CHECK-NOT: if ( -# CHECK: f[)IR"); -} - -TEST(LoopNest, getLoopAt) { - // Input IR: - // for (int i = 0; i < 100; i++) { - // for (int j = 0; j < 100; j++) { - // A[i, j] = sin(i * j); - // for (int k1 = 0; k1 < 200; k1++) { - // B[i, j, k1] = (A[i, j]) / (k1 + 1); - // } - // for (int k2 = 0; k2 < 300; k2++) { - // C[i, j, k2] = (A[i, j]) * (k2 + 1); - // } - // } - // } - BufPtr A = alloc( - "A", - std::vector({alloc(100), alloc(100)}), - kInt); - BufPtr B = alloc( - "B", - std::vector( - {alloc(100), alloc(100), alloc(200)}), - kInt); - BufPtr C = alloc( - "C", - std::vector( - {alloc(100), alloc(100), alloc(300)}), - kInt); - BufHandle a_buf(A); - BufHandle b_buf(B); - BufHandle c_buf(C); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle k1("k1", kInt); - VarHandle k2("k2", kInt); - auto store1 = Store::make(a_buf, {i, j}, sin(i * j)); - auto store2 = Store::make( - b_buf, {i, j, k1}, Div::make(Load::make(a_buf, {i, j}), (k1 + 1))); - auto store3 = Store::make( - c_buf, {i, j, k2}, Mul::make(Load::make(a_buf, {i, j}), (k2 + 1))); - auto for_k2 = For::make(k2, 0, 300, Block::make({store3})); - auto for_k1 = For::make(k1, 0, 200, Block::make({store2})); - auto for_j = For::make(j, 0, 100, Block::make({store1, for_k1, for_k2})); - auto for_i = For::make(i, 0, 100, for_j); - LoopNest l(Block::make({for_i}), {B, C}); - auto ret_k2 = l.getLoopAt(for_i, {0, 2}); - TORCH_CHECK(ret_k2 == for_k2); - - std::ostringstream oss; - oss << *ret_k2; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int k2 -# CHECK-NEXT: C[i, j, k2] = - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(LoopNest, TileSimple) { - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - const int M = 64, N = 64; - BufHandle a_buf("a", {M, N}, kFloat); - BufHandle b_buf("b", {M, N}, kFloat); - Tensor tensor = - Compute("f", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { - return a_buf.load({m, n}) + b_buf.load({m, n}) + 1.0f; - }); - - LoopNest l({tensor}); - std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - l.tile(loops[0], loops[1], 4, 8); - - // IR check - StmtPtr stmt = IRSimplifier::simplify(l.root_stmt()); - checkIR(stmt, R"IR( -# CHECK: for (int i_outer -# CHECK: for (int i_outer_1 -# CHECK: for (int i_inner -# CHECK: for (int i_inner_1 -# CHECK: f[ -# CHECK-NOT: for (int i_tail -# CHECK-NOT: for (int i_tail)IR"); - - // Correctness check - PaddedBuffer a_v(M, N, "a"); - PaddedBuffer b_v(M, N, "b"); - PaddedBuffer c_v(M, N, "c"); - PaddedBuffer c_ref(M, N, "c_ref"); - for (int m = 0; m < M; m++) { - for (int n = 0; n < N; n++) { - a_v(m, n) = 2 * m; - b_v(m, n) = 3 * n; - c_ref(m, n) = a_v(m, n) + b_v(m, n) + 1.0f; - } - } - - SimpleIREvaluator(stmt, {a_buf, b_buf, tensor})(a_v, b_v, c_v); - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - ExpectAllNear(c_v, c_ref, 1e-5); -} - -TEST(LoopNest, TileWithTails) { - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - const int M = 64, N = 64; - BufHandle a_buf("a", {M, N}, kFloat); - BufHandle b_buf("b", {M, N}, kFloat); - Tensor tensor = - Compute("f", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { - return a_buf.load({m, n}) + b_buf.load({m, n}) + 1.0f; - }); - - LoopNest l({tensor}); - std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - // 
NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - l.tile(loops[0], loops[1], 5, 9); - - // IR check - StmtPtr stmt = IRSimplifier::simplify(l.root_stmt()); - checkIR(stmt, R"IR( -# CHECK: for (int i_outer -# CHECK: for (int i_outer_1 -# CHECK: for (int i_inner -# CHECK: for (int i_inner_1 -# CHECK: f[ -# CHECK: for (int i_inner -# CHECK: f[ -# CHECK: for (int i_tail)IR"); - - // Correctness check - PaddedBuffer a_v(M, N, "a"); - PaddedBuffer b_v(M, N, "b"); - PaddedBuffer c_v(M, N, "c"); - PaddedBuffer c_ref(M, N, "c_ref"); - for (int m = 0; m < M; m++) { - for (int n = 0; n < N; n++) { - a_v(m, n) = 2 * m; - b_v(m, n) = 3 * n; - c_ref(m, n) = a_v(m, n) + b_v(m, n) + 1.0f; - } - } - - SimpleIREvaluator(stmt, {a_buf, b_buf, tensor})(a_v, b_v, c_v); - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - ExpectAllNear(c_v, c_ref, 1e-5); -} - -TEST(LoopNest, TileInMiddle) { - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - const int M = 8, N = 8, L = 8, K = 8; - BufHandle a_buf("a", {M, N, L, K}, kFloat); - BufHandle b_buf("b", {M, N, L, K}, kFloat); - Tensor tensor = Compute( - "f", - {M, N, L, K}, - [&](const ExprHandle& m, - const ExprHandle& n, - const ExprHandle& l, - const ExprHandle& k) { - return a_buf.load({m, n, l, k}) + b_buf.load({m, n, l, k}) + 1.0f; - }); - - LoopNest nest({tensor}); - std::vector loops = - nest.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - nest.tile(loops[1], loops[2], 3, 3); - - // IR check - StmtPtr stmt = IRSimplifier::simplify(nest.root_stmt()); - checkIR(stmt, R"IR( -# CHECK: for (int i -# CHECK: for (int i_outer -# CHECK: for (int i_outer_1 -# CHECK: for (int i_inner -# CHECK: for (int i_inner_1 -# CHECK: for (int i_1 -# CHECK: f[ -# CHECK: for (int i_tail_1 -# CHECK: for (int i_inner_1 -# CHECK: for (int i_1 -# CHECK: f[ -# CHECK: for (int i_tail)IR"); - - // Correctness check - PaddedBuffer a_v(M, N, L, K, "a"); - PaddedBuffer b_v(M, N, L, K, "b"); - PaddedBuffer c_v(M, N, L, K, "c"); - PaddedBuffer c_ref(M, N, L, K, "c_ref"); - for (int m = 0; m < M; m++) { - for (int n = 0; n < N; n++) { - for (int l = 0; l < L; l++) { - for (int k = 0; k < K; k++) { - a_v(m, n, l, k) = 2 * (m + l); - b_v(m, n, l, k) = 3 * (n + k); - c_ref(m, n, l, k) = a_v(m, n, l, k) + b_v(m, n, l, k) + 1.0f; - } - } - } - } - - SimpleIREvaluator(stmt, {a_buf, b_buf, tensor})(a_v, b_v, c_v); - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - ExpectAllNear(c_v, c_ref, 1e-5); -} - -TEST(LoopNest, SplitWithTailWithLoopOptions) { - const int M = 21; - BufHandle a_buf("a", {M}, kFloat); - BufHandle b_buf("b", {M}, kFloat); - Tensor tensor = Compute("f", {M}, [&](const ExprHandle& m) { - return a_buf.load(m) + b_buf.load(m) + 1.0f; - }); - ForPtr inner, tail; - - LoopNest l({tensor}); - auto loops = NodeFinder::find(l.root_stmt()); - ASSERT_GT(loops.size(), 0); - loops[0]->set_gpu_block_index(LoopOptions::IDX_Y); - LoopNest::splitWithTail(loops[0], 4, &inner, &tail); - ASSERT_NE(inner, nullptr); - ASSERT_NE(tail, nullptr); - ForPtr outer = loops[0]; - - // Outer loop carries loop axis bindings. - ASSERT_TRUE(outer->loop_options().is_gpu_block_index()); - ASSERT_EQ(outer->loop_options().gpu_block_index(), LoopOptions::IDX_Y); - - // Inner loop has none. - ASSERT_TRUE(inner->loop_options().isDefault()); - - // Tail loop has none. 
- ASSERT_TRUE(tail->loop_options().isDefault()); -} - -TEST(LoopNest, SplitWithMaskWithLoopOptions) { - const int M = 21; - BufHandle a_buf("a", {M}, kFloat); - BufHandle b_buf("b", {M}, kFloat); - Tensor tensor = Compute("f", {M}, [&](const ExprHandle& m) { - return a_buf.load(m) + b_buf.load(m) + 1.0f; - }); - ForPtr inner; - - LoopNest l({tensor}); - auto loops = NodeFinder::find(l.root_stmt()); - loops[0]->set_gpu_block_index(LoopOptions::IDX_Y); - LoopNest::splitWithMask(loops[0], 4, &inner); - ForPtr outer = loops[0]; - - // Outer loop carries loop axis bindings. - ASSERT_TRUE(outer->loop_options().is_gpu_block_index()); - ASSERT_EQ(outer->loop_options().gpu_block_index(), LoopOptions::IDX_Y); - - // Inner loop has none. - ASSERT_TRUE(inner->loop_options().isDefault()); -} - -TEST(LoopNest, ScheduleBroadcastAddBuffer) { - const int M = 4; - const int N = 5; - const int K = 6; - BufHandle a_buf("a", {M, N}, kFloat); - BufHandle b_buf("b", {N, K}, kFloat); - Tensor c = Compute( - "broadcast_add", - {M, N, K}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return a_buf.load(m, n) + b_buf.load(n, k); - }); - LoopNest l({c}); - StmtPtr stmt = l.root_stmt(); - - PaddedBuffer a_v(M, N, "a_v"); - for (int m = 0; m < M; m++) { - for (int n = 0; n < N; n++) { - a_v(m, n) = 7 * m * n; - } - } - a_v.Backup(); - - PaddedBuffer b_v(N, K, "b_v"); - for (int n = 0; n < N; n++) { - for (int k = 0; k < K; k++) { - b_v(n, k) = 11 * n * k; - } - } - b_v.Backup(); - - PaddedBuffer c_v(M, N, K, "c_buf"); - SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c}); - ir_eval(a_v, b_v, c_v); - - a_v.CheckBackup(); - b_v.CheckBackup(); - PaddedBuffer c_ref(M, N, K, "c_ref"); - for (int m = 0; m < M; m++) { - for (int n = 0; n < N; n++) { - for (int k = 0; k < K; k++) { - c_ref(m, n, k) = 7 * m * n + 11 * n * k; - } - } - } - ExpectAllNear(c_v, c_ref, 1e-5); -} - -TEST(LoopNest, ScheduleFunctionCall01) { - const int M = 4; - const int N = 5; - const int K = 6; - BufHandle a_buf("a", {M, N}, kFloat); - BufHandle b_buf("b", {N, K}, kFloat); - Tensor c = Compute( - "broadcast_add", - {M, N, K}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return a_buf.load(m, n) + b_buf.load(n, k); - }); - Tensor d = Compute( - "d", - {M, N, K}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return c.load(m, n, k) + 1; - }); - - LoopNest l({d}, {c, d}); - l.prepareForCodegen(); - StmtPtr stmt = l.root_stmt(); - std::ostringstream oss; - oss << *stmt; - ASSERT_GT(oss.str().size(), 100); - - PaddedBuffer a_v(M, N); - PaddedBuffer b_v(N, K); - PaddedBuffer c_v(M, N, K); - PaddedBuffer d_v(M, N, K); - PaddedBuffer d_ref(M, N, K); - - for (int i = 0; i < M; i++) { - for (int j = 0; j < N; j++) { - a_v(i, j) = i * i; - } - } - for (int i = 0; i < N; i++) { - for (int j = 0; j < K; j++) { - b_v(i, j) = j * j; - } - } - for (int i = 0; i < M; i++) { - for (int j = 0; j < N; j++) { - for (int k = 0; k < K; k++) { - d_ref(i, j, k) = a_v(i, j) + b_v(j, k) + 1; - } - } - } - - SimpleIREvaluator eval(stmt, {a_buf, b_buf, d}); - eval(a_v, b_v, d_v); - - ExpectAllNear(d_v, d_ref, 1e-5); -} - -TEST(LoopNest, ScheduleInlineSimple) { - const int M = 4; - const int N = 5; - const int K = 6; - BufHandle a_buf("a", {M, N}, kFloat); - BufHandle b_buf("b", {N, K}, kFloat); - BufHandle c_buf("c", {M, N}, kFloat); - BufHandle d_buf("d", {M, K}, kFloat); - - Tensor x = Compute( - "x", - {M, N, K}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return a_buf.load(m, n) * 
b_buf.load(n, k); - }); - Tensor y = Compute( - "y", - {M, N, K}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return c_buf.load(m, n) * d_buf.load(m, k) + x.load(m, n, k); - }); - - LoopNest l1({y}, {x, y}); - LoopNest l2(l1); - l2.computeInline(x.buf()); - - l1.prepareForCodegen(); - l2.prepareForCodegen(); - - StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt()); - StmtPtr stmt2 = IRSimplifier::simplify(l2.root_stmt()); - - SimpleIREvaluator eval1(stmt1, {a_buf, b_buf, c_buf, d_buf, y}); - SimpleIREvaluator eval2(stmt2, {a_buf, b_buf, c_buf, d_buf, y}); - - PaddedBuffer a_v(M, N); - PaddedBuffer b_v(N, K); - PaddedBuffer c_v(M, N); - PaddedBuffer d_v(M, K); - - for (int i = 0; i < M; i++) { - for (int j = 0; j < N; j++) { - a_v(i, j) = i * i; - } - } - for (int i = 0; i < N; i++) { - for (int j = 0; j < K; j++) { - b_v(i, j) = j * j; - } - } - for (int i = 0; i < M; i++) { - for (int j = 0; j < N; j++) { - c_v(i, j) = i + j; - } - } - for (int i = 0; i < M; i++) { - for (int j = 0; j < K; j++) { - d_v(i, j) = i * j; - } - } - - PaddedBuffer y_1(M, N, K); - PaddedBuffer y_2(M, N, K); - - eval1(a_v, b_v, c_v, d_v, y_1); - eval2(a_v, b_v, c_v, d_v, y_2); - ExpectAllNear(y_1, y_2, 1e-5); - std::ostringstream oss1, oss2; - oss1 << *stmt1; - oss2 << *stmt2; - ASSERT_GT(oss1.str().size(), oss2.str().size()); -} - -static std::string remove_space(const std::string& str) { - std::string str_new = str; - str_new.erase( - remove_if(str_new.begin(), str_new.end(), isspace), str_new.end()); - return str_new; -} - -void InlineFunc01Helper(const std::vector& inline_order) { - const int M = 4; - const int N = 5; - const int K = 6; - BufHandle a_buf("a", {M, N}, kFloat); - BufHandle b_buf("b", {N, K}, kFloat); - BufHandle c_buf("c", {M, N}, kFloat); - BufHandle d_buf("d", {M, K}, kFloat); - - Tensor x = Compute( - "x", - {M, N, K}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return a_buf.load(m, n) * b_buf.load(n, k); - }); - Tensor y = Compute( - "y", - {M, N, K}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return c_buf.load(m, n) * d_buf.load(m, k) + x.load(m, n, k); - }); - Tensor z = Compute( - "z", - {M, N, K}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return x.load(m, n, k) + y.load(m, n, k); - }); - - LoopNest l({z}, {x, y, z}); - for (const std::string& order : inline_order) { - if (order == "x") { - l.computeInline(x.buf()); - } else if (order == "y") { - l.computeInline(y.buf()); - } else { - throw std::runtime_error("Invalid order: " + order); - } - } - l.prepareForCodegen(); - StmtPtr stmt = l.root_stmt(); - - std::ostringstream oss; - oss << *stmt; - std::string str1 = remove_space(oss.str()); - - { - PaddedBuffer a_v(M, N); - PaddedBuffer b_v(N, K); - PaddedBuffer c_v(M, N); - PaddedBuffer d_v(M, K); - - for (int i = 0; i < M; i++) { - for (int j = 0; j < N; j++) { - a_v(i, j) = i * i; - } - } - for (int i = 0; i < N; i++) { - for (int j = 0; j < K; j++) { - b_v(i, j) = j * j; - } - } - for (int i = 0; i < M; i++) { - for (int j = 0; j < N; j++) { - c_v(i, j) = i + j; - } - } - for (int i = 0; i < M; i++) { - for (int j = 0; j < K; j++) { - d_v(i, j) = i * j; - } - } - - PaddedBuffer z_v(M, N, K); - PaddedBuffer z_ref(M, N, K); - for (int m = 0; m < M; m++) { - for (int n = 0; n < N; n++) { - for (int k = 0; k < K; k++) { - z_ref(m, n, k) = a_v(m, n) * b_v(n, k) * 2 + c_v(m, n) * d_v(m, k); - } - } - } - - SimpleIREvaluator eval(stmt, {a_buf, b_buf, c_buf, d_buf, z}); - eval(a_v, b_v, 
c_v, d_v, z_v); - ExpectAllNear(z_v, z_ref, 1e-5); - } - - if (inline_order.size() == 2) { - Tensor z2 = Compute( - "z", - {M, N, K}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return a_buf.load(m, n) * b_buf.load(n, k) + - (c_buf.load(m, n) * d_buf.load(m, k) + - a_buf.load(m, n) * b_buf.load(n, k)); - }); - LoopNest l2({z2}); - l2.prepareForCodegen(); - StmtPtr stmt2 = l2.root_stmt(); - - std::ostringstream oss2; - oss2 << *stmt2; - std::string str2 = remove_space(oss2.str()); - - ASSERT_EQ(str1, str2); - ASSERT_GT(str1.size(), 100); - } -} - -TEST(LoopNest, ScheduleInlineFunc01) { - InlineFunc01Helper({"x", "y"}); - InlineFunc01Helper({"y", "x"}); - InlineFunc01Helper({"x"}); - InlineFunc01Helper({"y"}); - InlineFunc01Helper({}); -} - -// Make sure we cache random vars if we should. -TEST(LoopNest, ScheduleInlineRandom) { - const int M = 4; - const int N = 5; - const int K = 6; - - Tensor x = Compute( - "x", - {M, N, K}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return Mod::make(Intrinsics::make(kRand, kInt), 5); - }); - Tensor y = Compute( - "y", - {M, N, K}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return x.load(m, n, k) + x.load(m, n, k); - }); - - LoopNest l1({y}, {x, y}); - l1.computeInline(x.buf()); - - // would normally compare results but Rand isn't implemented in the - // SimpleIREvaluator, even if we could seed it. - StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt()); - - // Check the IR we produced - checkIR(stmt1, R"IR( -# CHECK: for (int i = 0; i < 4; i++) -# CHECK: for (int i_1 = 0; i_1 < 5; i_1++) -# CHECK: for (int i_2 = 0; i_2 < 6; i_2++) -# CHECK: int x = rand(); -# CHECK: y[i, i_1, i_2] = 2 * (x % 5);)IR"); -} - -// Make sure we don't cache random vars that are not being inlined. -TEST(LoopNest, ScheduleInlineRandomUnrelated) { - const int M = 4; - const int N = 5; - const int K = 6; - - Tensor x = Compute( - "x", - {M, N, K}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return m * n * k; - }); - Tensor y = Compute( - "y", - {M, N, K}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return x.load(m, n, k) + Intrinsics::make(kRand, kInt) + - Intrinsics::make(kRand, kInt); - }); - - LoopNest l1({y}, {x, y}); - l1.computeInline(x.buf()); - - // would normally compare results but Rand isn't implemented in the - // SimpleIREvaluator, even if we could seed it. - StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt()); - - // Check the IR we produced - checkIR(stmt1, R"IR( -# CHECK: for (int i = 0; i < 4; i++) -# CHECK: for (int i_1 = 0; i_1 < 5; i_1++) -# CHECK: for (int i_2 = 0; i_2 < 6; i_2++) -# CHECK: y[i, i_1, i_2] = ((i * i_1) * i_2 + (rand())) + (rand());)IR"); -} - -// Make sure we generate the right number of random values == the dimensionality -// of the production tensor. -TEST(LoopNest, ScheduleInlineRandomLowerDimensions) { - const int M = 4; - const int N = 5; - const int K = 6; - - Tensor x = Compute("x", {M}, [&](const VarHandle& m) { - return Mod::make(Intrinsics::make(kRand, kInt), 5); - }); - Tensor y = Compute( - "y", - {M, N, K}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return x.load(m) + x.load(m); - }); - - LoopNest l1({y}, {x, y}); - l1.computeInline(x.buf()); - - // would normally compare results but Rand isn't implemented in the - // SimpleIREvaluator, even if we could seed it. 
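-  // Since x has rank 1, the cached rand() should be hoisted to the
-  // outermost loop: one random value per i, reused across the i_1 and i_2
-  // loops below, rather than one fresh value per element.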
- StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt()); - - // Check the IR we produced - checkIR(stmt1, R"IR( -# CHECK: for (int i = 0; i < 4; i++) -# CHECK: int x = rand(); -# CHECK: for (int i_1 = 0; i_1 < 5; i_1++) -# CHECK: for (int i_2 = 0; i_2 < 6; i_2++) -# CHECK: y[i, i_1, i_2] = 2 * (x % 5);)IR"); -} - -// Make sure we don't screw up intrinsics thinking they're rand. -TEST(LoopNest, ScheduleInlineIntrinsics) { - const int M = 4; - const int N = 5; - const int K = 6; - BufHandle a_buf("a", {M, N}, kFloat); - BufHandle b_buf("b", {N, K}, kFloat); - - Tensor x = Compute( - "x", - {M, N, K}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return a_buf.load(m, n) * b_buf.load(n, k); - }); - Tensor y = Compute( - "y", - {M, N, K}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return Intrinsics::make(kSqrt, x.load(m, n, k)); - }); - - PaddedBuffer a_v(M, N); - PaddedBuffer b_v(N, K); - - for (int i = 0; i < M; i++) { - for (int j = 0; j < N; j++) { - a_v(i, j) = i * i; - } - } - for (int i = 0; i < N; i++) { - for (int j = 0; j < K; j++) { - b_v(i, j) = j * j; - } - } - - LoopNest l1({y}, {x, y}); - LoopNest l2(l1); - l2.computeInline(x.buf()); - - l1.prepareForCodegen(); - l2.prepareForCodegen(); - - StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt()); - StmtPtr stmt2 = IRSimplifier::simplify(l2.root_stmt()); - - SimpleIREvaluator eval1(stmt1, {a_buf, b_buf, y}); - SimpleIREvaluator eval2(stmt2, {a_buf, b_buf, y}); - - PaddedBuffer y_1(M, N, K); - PaddedBuffer y_2(M, N, K); - - eval1(a_v, b_v, y_1); - eval2(a_v, b_v, y_2); - ExpectAllNear(y_1, y_2, 1e-5); - std::ostringstream oss1, oss2; - oss1 << *stmt1; - oss2 << *stmt2; - ASSERT_GT(oss1.str().size(), oss2.str().size()); -} - -// Make sure we can handle rand and non-rand intrinsics. -TEST(LoopNest, ScheduleInlineRandWithIntrinsics) { - const int M = 4; - const int N = 5; - const int K = 6; - - Tensor x = Compute( - "x", - {M, N, K}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return Intrinsics::make(kRand, kFloat); - }); - Tensor y = Compute( - "y", - {M, N, K}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return Intrinsics::make(kSqrt, x.load(m, n, k)); - }); - - LoopNest l1({y}, {x, y}); - l1.computeInline(x.buf()); - - StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt()); - - // Check the IR we produced - checkIR(stmt1, R"IR( -# CHECK: for (int i = 0; i < 4; i++) -# CHECK: for (int i_1 = 0; i_1 < 5; i_1++) -# CHECK: for (int i_2 = 0; i_2 < 6; i_2++) -# CHECK: float x = rand(); -# CHECK: y[i, i_1, i_2] = sqrt(x);)IR"); -} - -// Split a Compute then inline it into another compute. -TEST(LoopNest, ScheduleSplitAThenInline) { - Tensor a = Compute("a", {18}, [&](const VarHandle& i) { return i * i; }); - Tensor b = Compute( - "b", {2}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); }); - - LoopNest l({b}, {a, b}); - std::vector loops = l.getAllLoopNestsWritingToBuf(a.buf()).at(0); - LoopNest::splitWithMask(loops[0], 4); - ASSERT_FALSE(l.computeInline(a.buf())); -} - -// Split a Compute then inline another Compute into it. 
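-// Splitting the consumer only reshapes the consumer's own loops; the
-// producer's store index stays a plain variable (a[i] = i * i), so
-// computeInline can still substitute a's definition into b. A rough sketch
-// of the result after inlining and simplification:
-//
-//   for (int j_outer ...) {
-//     for (int j_inner ...) {
-//       b[...] = (j + 8) * (j + 8);  // j standing for the recombined index
-//     }
-//   }
-//
-// Contrast with ScheduleSplitAThenInline above, where splitting the
-// *producer* leaves a compound store index and inlining has to fail.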
-TEST(LoopNest, ScheduleSplitBThenInline) { - Tensor a = Compute("a", {18}, [&](const VarHandle& i) { return i * i; }); - Tensor b = Compute( - "b", {6}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); }); - - LoopNest l({b}, {a, b}); - std::vector loops = l.getAllLoopNestsWritingToBuf(b.buf()).at(0); - LoopNest::splitWithMask(loops[0], 3); - l.computeInline(a.buf()); - l.prepareForCodegen(); - StmtPtr s = IRSimplifier::simplify(l.root_stmt()); - - std::vector output(6, 0); - SimpleIREvaluator eval(s, {b}); - eval(output); - - for (int i = 0; i < 6; ++i) { - ASSERT_EQ(output[i], (i + 8) * (i + 8)); - } -} - -// Split a Compute twice then inline it. -TEST(LoopNest, ScheduleSplitTwiceThenInline) { - Tensor a = Compute("a", {18}, [&](const VarHandle& i) { return i * i; }); - Tensor b = Compute( - "b", {2}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); }); - ForPtr i_inner; - - LoopNest l({b}, {a, b}); - std::vector loops = l.getAllLoopNestsWritingToBuf(a.buf()).at(0); - LoopNest::splitWithMask(loops[0], 4, &i_inner); - LoopNest::splitWithMask(i_inner, 2); - ASSERT_FALSE(l.computeInline(a.buf())); -} - -// Inline a Compute, then split. -TEST(LoopNest, ScheduleInlineThenSplit) { - Tensor a = Compute("a", {18}, [&](const VarHandle& i) { return i * i; }); - Tensor b = Compute( - "b", {6}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); }); - - LoopNest l({b}, {a, b}); - l.computeInline(a.buf()); - - std::vector loops = NodeFinder::find(l.root_stmt()); - LoopNest::splitWithMask(loops.back(), 3); - l.prepareForCodegen(); - StmtPtr s = IRSimplifier::simplify(l.root_stmt()); - std::vector output(6, 0); - SimpleIREvaluator eval(s, {b}); - eval(output); - - for (int i = 0; i < 6; ++i) { - ASSERT_EQ(output[i], (i + 8) * (i + 8)); - } -} - -// Split a Compute, inline it, then split the result. -TEST(LoopNest, ScheduleSplitInlineThenSplit) { - Tensor a = Compute("a", {18}, [&](const VarHandle& i) { return i * i; }); - Tensor b = Compute( - "b", {16}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); }); - - LoopNest l({b}, {a, b}); - auto loops = NodeFinder::find(l.root_stmt()); - LoopNest::splitWithMask(loops.back(), 2); - l.computeInline(a.buf()); - - loops = NodeFinder::find(l.root_stmt()); - LoopNest::splitWithMask(loops.front(), 2); - l.prepareForCodegen(); - StmtPtr s = IRSimplifier::simplify(l.root_stmt()); - std::vector output(16, 0); - SimpleIREvaluator eval(s, {b}); - eval(output); - - for (int i = 0; i < 16; ++i) { - ASSERT_EQ(output[i], (i + 8) * (i + 8)); - } -} - -// Oversplit a loop that is simplified out after inlining. -TEST(LoopNest, ScheduleSplitInlineSimplify) { - Tensor a = Compute("a", {18}, [&](const VarHandle& i) { - return ExprHandle(4) * i - ExprHandle(2) * i; - }); - Tensor b = Compute( - "b", {2}, [&](const VarHandle& j) { return a.load(j) - ExprHandle(1); }); - - LoopNest l({b}, {a, b}); - std::vector loops = l.getAllLoopNestsWritingToBuf(a.buf()).at(0); - LoopNest::splitWithMask(loops[0], 4); - ASSERT_FALSE(l.computeInline(a.buf())); -} - -// Inline a Compute with two consumers. 
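-// The three tensors below form a small diamond:
-//
-//   a(i)    = i * i
-//   b(j)    = a(j + 8)
-//   c(k, l) = a(k) * b(l)
-//
-// 'a' is read directly by 'c' and indirectly through 'b'; inlining just
-// 'a' should substitute it into both consumers and leave 'b' materialized.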
-TEST(LoopNest, ScheduleInlineThreeMixedOnce) { - Tensor a = Compute("a", {18}, [&](const VarHandle& i) { return i * i; }); - Tensor b = Compute( - "b", {6}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); }); - Tensor c = Compute("c", {4, 3}, [&](const VarHandle& k, const VarHandle& l) { - return a.load(k) * b.load(l); - }); - - LoopNest l({c}, {a, b, c}); - std::vector loops = l.getAllLoopNestsWritingToBuf(a.buf()).at(0); - l.computeInline(a.buf()); - l.prepareForCodegen(); - - StmtPtr s = IRSimplifier::simplify(l.root_stmt()); - std::vector output(4 * 3, 0); - SimpleIREvaluator eval(s, {c}); - eval(output); - - for (int k = 0; k < 4; ++k) { - for (int l = 0; l < 3; ++l) { - ASSERT_EQ(output[k * 3 + l], (k) * (k) * (l + 8) * (l + 8)); - } - } -} - -// Inline Compute A into B, then inline B into C. -TEST(LoopNest, ScheduleInlineThreeMixedTwice) { - Tensor a = Compute("a", {18}, [&](const VarHandle& i) { return i * i; }); - Tensor b = Compute( - "b", {6}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); }); - Tensor c = Compute("c", {4, 3}, [&](const VarHandle& k, const VarHandle& l) { - return a.load(k) * b.load(l); - }); - - LoopNest l({c}, {a, b, c}); - std::vector loops = l.getAllLoopNestsWritingToBuf(a.buf()).at(0); - l.computeInline(a.buf()); - l.computeInline(b.buf()); - l.prepareForCodegen(); - - StmtPtr s = IRSimplifier::simplify(l.root_stmt()); - std::vector output(4 * 3, 0); - SimpleIREvaluator eval(s, {c}); - eval(output); - - for (int k = 0; k < 4; ++k) { - for (int l = 0; l < 3; ++l) { - ASSERT_EQ(output[k * 3 + l], (k) * (k) * (l + 8) * (l + 8)); - } - } -} - -// Inline a Compute that is both a producer and consumer. -TEST(LoopNest, ScheduleInlineThreeMixedInner) { - Tensor a = Compute("a", {18}, [&](const VarHandle& i) { return i * i; }); - Tensor b = Compute( - "b", {6}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); }); - Tensor c = Compute("c", {4, 3}, [&](const VarHandle& k, const VarHandle& l) { - return a.load(k) * b.load(l); - }); - - LoopNest l({c}, {a, b, c}); - std::vector loops = l.getAllLoopNestsWritingToBuf(a.buf()).at(0); - l.computeInline(b.buf()); - l.prepareForCodegen(); - - StmtPtr s = IRSimplifier::simplify(l.root_stmt()); - std::vector output(4 * 3, 0); - SimpleIREvaluator eval(s, {c}); - eval(output); - - for (int k = 0; k < 4; ++k) { - for (int l = 0; l < 3; ++l) { - ASSERT_EQ(output[k * 3 + l], (k) * (k) * (l + 8) * (l + 8)); - } - } -} - -// Split 3 Computes, then inline the first two into the last. 
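-// As in ScheduleSplitAThenInline, splitting the producer 'a' gives its
-// store a compound index (something like a[4 * i_outer + i_inner]), which
-// cannot be matched against the loads a.load(k) and a.load(j + 8), so the
-// computeInline call is expected to fail.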
-TEST(LoopNest, ScheduleInlineThreeMixedSplit) {
-  Tensor a = Compute("a", {18}, [&](const VarHandle& i) { return i * i; });
-  Tensor b = Compute(
-      "b", {6}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); });
-  Tensor c = Compute("c", {4, 3}, [&](const VarHandle& k, const VarHandle& l) {
-    return a.load(k) * b.load(l);
-  });
-
-  LoopNest l({c}, {a, b, c});
-  std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(a.buf()).at(0);
-  LoopNest::splitWithMask(loops[0], 4);
-  loops = l.getAllLoopNestsWritingToBuf(b.buf()).at(0);
-  LoopNest::splitWithMask(loops[0], 3);
-  loops = l.getAllLoopNestsWritingToBuf(c.buf()).at(0);
-  LoopNest::splitWithMask(loops[0], 2);
-
-  ASSERT_FALSE(l.computeInline(a.buf()));
-}
-
-// Check that inlining works for output tensors too.
-TEST(LoopNest, ScheduleInlineOutputTensors) {
-  const int M = 4;
-  const int N = 5;
-  const int K = 6;
-
-  Tensor x = Compute(
-      "x",
-      {M, N, K},
-      [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
-        return m * n * k;
-      });
-  Tensor y = Compute(
-      "y",
-      {M, N, K},
-      [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
-        return x.load(m, n, k) + m;
-      });
-
-  LoopNest l1({x, y});
-  l1.computeInline(x.buf());
-
-  // x is itself an output, so inlining must not eliminate the writes to it;
-  // both x and y should still be computed in full.
-  StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt());
-
-  // Check the IR we produced
-  checkIR(stmt1, R"IR(
-# CHECK: for (int i = 0; i < 4; i++)
-# CHECK: for (int i_1 = 0; i_1 < 5; i_1++)
-# CHECK: for (int i_2 = 0; i_2 < 6; i_2++)
-# CHECK: x[i, i_1, i_2] = (i * i_1) * i_2;
-# CHECK: for (int i_3 = 0; i_3 < 4; i_3++)
-# CHECK: for (int i_4 = 0; i_4 < 5; i_4++)
-# CHECK: for (int i_5 = 0; i_5 < 6; i_5++)
-# CHECK: y[i_3, i_4, i_5] = i_3 + (i_3 * i_4) * i_5;)IR");
-}
-
-TEST(LoopNest, ScheduleInlineWithCompoundIndices) {
-  // Input IR:
-  // for (int64_t i = 0; i < 100; i++) {
-  //   A[i*2,i] = i * 500ll;
-  // }
-  // for (int64_t j = 0; j < 100; j++) {
-  //   B[0ll,j] = A[0, j] + j * 100ll;
-  // }
-  BufHandle a_buf("A", {20, 100}, kLong);
-  BufHandle b_buf("B", {20, 100}, kLong);
-  VarHandle i("i", kLong);
-  VarHandle j("j", kLong);
-  auto forI = For::make(
-      i,
-      0,
-      100,
-      Store::make(a_buf, {i * 2, i}, Mul::make(i, static_cast<int64_t>(500))));
-  auto forJ = For::make(
-      j,
-      0,
-      100,
-      Store::make(
-          b_buf,
-          {static_cast<int64_t>(0), j},
-          Add::make(
-              Load::make(a_buf, {static_cast<int64_t>(0), j}),
-              Mul::make(j, static_cast<int64_t>(100)))));
-  auto par = Block::make({forI, forJ});
-
-  LoopNest l(par, {b_buf.node()});
-  // Inlining should fail since the producer has a compound expr as an index.
-  ASSERT_FALSE(l.computeInline(a_buf.node()));
-
-  // The input statement must remain as is.
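-  // ("Compound" means the store index is an arithmetic expression of the
-  // loop variable, A[i * 2, i], so there is no way to rewrite the
-  // consumer's A[0, j] load in terms of the producer's right-hand side.)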
- checkIR(l.root_stmt(), R"IR( - # CHECK: for (int64_t i = 0; - # CHECK-NEXT: A[ - # CHECK: for (int64_t j = 0; - # CHECK-NEXT: B[)IR"); -} - -TEST(LoopNest, ScheduleInlineConsumerIndicesWithCast) { - // Input IR: - // for (int64_t i = 0; i < 100; i++) { - // A[0ll,i] = i * 500ll; - // } - // for (int64_t j = 0; j < 100; j++) { - // B[0ll,j] = A[(int64_t)0, j] + j * 100ll; - // } - BufHandle a_buf("A", {20, 100}, kLong); - BufHandle b_buf("B", {20, 100}, kLong); - VarHandle i("i", kLong); - VarHandle j("j", kLong); - auto forI = For::make( - i, - 0, - 100, - Store::make( - a_buf, - {static_cast(0), i}, - Mul::make(i, static_cast(500)))); - auto forJ = For::make( - j, - 0, - 100, - Store::make( - b_buf, - {static_cast(0), j}, - Add::make( - Load::make(a_buf, {0, j}), - Mul::make(j, static_cast(100))))); - auto par = Block::make({forI, forJ}); - - LoopNest l(par, {b_buf.node()}); - ASSERT_TRUE(l.computeInline(a_buf.node())); - - checkIR(l.root_stmt(), R"IR( - # CHECK: for (int64_t j = 0; j < 100; j++) { - # CHECK: B[0ll, j] = j * 500ll + j * 100ll; - # CHECK: })IR"); -} - -TEST(LoopNest, ScheduleInlineProducerIndicesWithCast) { - // Input IR: - // for (int64_t i = 0; i < 100; i++) { - // A[(int64_t)0,i] = i * 500ll; - // } - // for (int64_t j = 0; j < 100; j++) { - // B[0ll,j] = A[0ll, j] + j * 100ll; - // } - BufHandle a_buf("A", {20, 100}, kLong); - BufHandle b_buf("B", {20, 100}, kLong); - VarHandle i("i", kLong); - VarHandle j("j", kLong); - auto forI = For::make( - i, - 0, - 100, - Store::make(a_buf, {0, i}, Mul::make(i, static_cast(500)))); - auto forJ = For::make( - j, - 0, - 100, - Store::make( - b_buf, - {static_cast(0), j}, - Add::make( - Load::make(a_buf, {static_cast(0), j}), - Mul::make(j, static_cast(100))))); - auto par = Block::make({forI, forJ}); - - LoopNest l(par, {b_buf.node()}); - ASSERT_TRUE(l.computeInline(a_buf.node())); - - checkIR(l.root_stmt(), R"IR( - # CHECK: for (int64_t j = 0; j < 100; j++) { - # CHECK: B[0ll, j] = j * 500ll + j * 100ll; - # CHECK: })IR"); -} - -TEST(LoopNest, ScheduleFuserStyle) { - const int kVectorSize = 8; - const int kVectorCount = 128; - const int kTotalSize = kVectorSize * kVectorCount; - - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); - - Tensor b = - Compute("f", {kTotalSize}, [&](const std::vector& axes) { - return a_buf.load(axes[0]) + 11.0f; - }); - - Tensor c = - Compute("g", {kTotalSize}, [&](const std::vector& axes) { - return b.load(axes[0]) + 1.0f; - }); - - LoopNest l({b, c}); - l.prepareForCodegen(); - StmtPtr s = l.root_stmt(); - - std::vector a_data(kTotalSize, 7.0f); - std::vector b_data(kTotalSize, 0.0f); - std::vector c_data(kTotalSize, 0.0f); - SimpleIREvaluator(s, {a_buf, b, c})(a_data, b_data, c_data); - - for (int i = 0; i < kTotalSize; i++) { - ASSERT_EQ(b_data[i], 18.0f); - ASSERT_EQ(c_data[i], 19.0f); - } -} - -TEST(LoopNest, ScheduleFuserThreeArg) { - const int kVectorSize = 8; - const int kVectorCount = 128; - const int kTotalSize = kVectorSize * kVectorCount; - - BufHandle a("A", {ExprHandle(kTotalSize)}, kFloat); - BufHandle b("B", {ExprHandle(kTotalSize)}, kFloat); - BufHandle c("C", {ExprHandle(kTotalSize)}, kFloat); - BufHandle d("D", {ExprHandle(kTotalSize)}, kFloat); - - Tensor e = Compute("e", {kTotalSize}, [&](const VarHandle& i) { - return a.load(i) + b.load(i); - }); - Tensor f = Compute("f", {kTotalSize}, [&](const VarHandle& i) { - return e.load(i) + c.load(i); - }); - Tensor g = Compute("g", {kTotalSize}, [&](const VarHandle& i) { - return f.load(i) + d.load(i); - }); - - LoopNest l({g}, 
{e, f, g}); - l.computeInline(l.getLoopBodyFor(e)); - l.computeInline(l.getLoopBodyFor(f)); - l.prepareForCodegen(); - StmtPtr s = l.root_stmt(); - - std::vector a_data(kTotalSize, 1.0f); - std::vector b_data(kTotalSize, 2.0f); - std::vector c_data(kTotalSize, 3.0f); - std::vector d_data(kTotalSize, 4.0f); - std::vector g_data(kTotalSize, 0.0f); - SimpleIREvaluator(s, {a, b, c, d, g})(a_data, b_data, c_data, d_data, g_data); - - for (int i = 0; i < kTotalSize; i++) { - ASSERT_EQ(g_data[i], 10.0f); - } -} - -TEST(LoopNest, ScheduleDynamicShape2D) { - auto testWithSize = [](int32_t M, int32_t N) { - VarHandle m("m", kInt); - VarHandle n("n", kInt); - BufHandle a("a", {m, n}, kFloat); - BufHandle b("b", {m, n}, kFloat); - Tensor c = - Compute("c", {m, n}, [&](const VarHandle& i, const VarHandle& j) { - return a.load(i, j) + b.load(i, j); - }); - LoopNest l({c}); - StmtPtr s = l.root_stmt(); - SimpleIREvaluator cg(s, {a, b, c, m, n}); - std::vector aData(M * N, 1.0f); - std::vector bData(M * N, 2.0f); - std::vector cData(M * N, 0.0f); - cg.call({aData, bData, cData, M, N}); - ExpectAllNear(cData, std::vector(M * N, 3.0f), 1e-7); - }; - testWithSize(1, 8); - testWithSize(16, 32); - testWithSize(37, 11); -} - -TEST(LoopNest, LoopNestComputeAt_1) { - // Verify that compute_at works on the following example: - // - // for (int i_a = 0; i_a < N; i_a++) { - // A[i_a] = i_a * i_a - // } - // for (int i_b = 0; i_b < N; i_b++) { - // B[i_b] = A[i_b] - // } - // - // After the transformation the i_b loop should have an allocation for a temp - // buffer and that buffer should be used in computation of B. No use of A - // should be in that loop after the transformation. Also, computation of A - // should not be inlined into B. Instead, it should be computed into the temp, - // and the temp should be used in B. - VarHandle N("N", kInt); - Tensor A = Compute("A", {N}, [&](const VarHandle& i_a) { return i_a * i_a; }); - Tensor B = - Compute("B", {N}, [&](const VarHandle& i_b) { return A.load(i_b); }); - LoopNest l({B}, {A, B}); - std::vector loops = l.getAllLoopNestsWritingToBuf(B.buf()).at(0); - LoopNest::computeAt(l.getLoopBodyFor(A), loops[0]); - l.prepareForCodegen(); - SimpleIREvaluator cg(l.root_stmt(), {B, N}); - StmtPtr s = cg.stmt(); - - checkIR(s, R"IR( -# CHECK: Allocate(temp); // dtype=int, dims=[1] -# CHECK: for (int i = 0; i < N; i++) -# CHECK: temp[ -# CHECK-NOT: A[ -# CHECK: B[i_1] = temp[0] -# CHECK: Free(temp))IR"); - - // Now check that the loop still produces the correct result. 
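-  // With A computed at B's loop, each iteration materializes exactly one
-  // element of A into temp (dims=[1] above) and B reads it straight back,
-  // so the observable values must still be B[i] = i * i.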
- std::vector b_data(100, 0); - cg.call({b_data, 100}); - - std::vector b_ref(100, 0); - for (int i = 0; i < 100; i++) { - b_ref[i] = i * i; - } - assertAllEqual(b_data, b_ref); -} - -TEST(LoopNest, LoopNestComputeAt_2) { - // Verify that compute_at works on the following example: - // - // for (int py = 0; py < H+1; py++) { - // for (int px = 0; px < W+1; px++) { - // p[py, px] = py*px - // } - // } - // for (int cy = 0; cy < H; cy++) { - // for (int cx = 0; cx < W; cx++) { - // c[py, px] = p[cy,cx] + p[cy+1,cx] + - // p[cy,cx+1] + p[cy+1,cx+1] - // } - // } - - const int kW = 16, kH = 16; - VarHandle W("W", kInt); - VarHandle H("H", kInt); - Tensor p = Compute( - "prod", {H + 1, W + 1}, [&](const VarHandle& py, const VarHandle& px) { - return px * py; - }); - Tensor c = - Compute("cons", {H, W}, [&](const VarHandle& y, const VarHandle& x) { - return p.load(y, x) + p.load(y + 1, x) + p.load(y, x + 1) + - p.load(y + 1, x + 1); - }); - - std::vector c_ref(kW * kH, 0); - for (int y = 0; y < kH; y++) { - for (int x = 0; x < kW; x++) { - c_ref[y * kW + x] = y * x + (y + 1) * x + y * (x + 1) + (y + 1) * (x + 1); - } - } - LoopNest orig_loopnest({c}, {p, c}); - - { - // First let's try to compute P at axis cy (the outer loop) - LoopNest l(orig_loopnest); - std::vector loops = l.getAllLoopNestsWritingToBuf(c.buf()).at(0); - LoopNest::computeAt(l.getLoopBodyFor(p), loops[0]); - l.prepareForCodegen(); - SimpleIREvaluator cg(l.root_stmt(), {c, W, H}); - StmtPtr s = cg.stmt(); - - // Check the IR we produced - checkIR(s, R"IR( -# CHECK: Allocate(temp); // dtype=int, dims=[2, W + 1] -# CHECK: for (int i_2 = 0; i_2 < H; i_2++) -# CHECK: for -# CHECK: for -# CHECK: for (int i_3 = 0; i_3 < W; i_3++) -# CHECK-NOT: prod[ -# CHECK: cons[ -# CHECK: Free(temp))IR"); - - // Now check that the loop still produces the correct result. - std::vector c_data(kW * kH, 0); - cg.call({c_data, kW, kH}); - - assertAllEqual(c_data, c_ref); - } - { - // Now let's try to compute P at axis cx (the inner loop) - LoopNest l(orig_loopnest); - std::vector loops = l.getAllLoopNestsWritingToBuf(c.buf()).at(0); - LoopNest::computeAt(l.getLoopBodyFor(p), loops[1]); - l.prepareForCodegen(); - SimpleIREvaluator cg(l.root_stmt(), {c, W, H}); - StmtPtr s = cg.stmt(); - - // Check the IR we produced - checkIR(s, R"IR( -# CHECK: Allocate(temp); // dtype=int, dims=[2, 2] -# CHECK: for (int i_2 = 0; i_2 < H; i_2++) -# CHECK: for (int i_3 = 0; i_3 < W; i_3++) -# CHECK: for -# CHECK: for -# CHECK-NOT: prod[ -# CHECK: cons[ -# CHECK: Free(temp))IR"); - - // Now check that the loop still produces the correct result. - std::vector c_data(kW * kH, 0); - cg.call({c_data, kW, kH}); - - assertAllEqual(c_data, c_ref); - } -} - -TEST(LoopNest, LoopNestComputeAt_3) { - // Verify that compute_at works on the following example: - // - // A(x,y) = x*y - // B(x,y) = A(x, y) - // C(x,y) = B(x+1, y) - // D(x,y) = A(x, y+1) + C(x, y) - // - // i.e. when 'A' comes to 'D' directly and indirectly through 'C'. 
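-  // Only the direct use A(x, y+1) should be redirected into a temp; the
-  // indirect path runs through B and C, which are materialized before D.
-  // That is why the CHECK-NOTs below only need to rule out loads of A
-  // inside D's own loops.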
- - const int kW = 16, kH = 16; - VarHandle W("W", kInt); - VarHandle H("H", kInt); - Tensor A = Compute( - "A", {H + 1, W + 1}, [&](const VarHandle& ay, const VarHandle& ax) { - return ax * ay; - }); - Tensor B = Compute( - "B", {H + 1, W + 1}, [&](const VarHandle& by, const VarHandle& bx) { - return A.load(by, bx); - }); - Tensor C = - Compute("C", {H, W}, [&](const VarHandle& cy, const VarHandle& cx) { - return B.load(cy, cx + 1); - }); - Tensor D = - Compute("D", {H, W}, [&](const VarHandle& dy, const VarHandle& dx) { - return A.load(dy + 1, dx) + C.load(dy, dx); - }); - - std::vector c_ref(kW * kH, 0); - for (int y = 0; y < kH; y++) { - for (int x = 0; x < kW; x++) { - c_ref[y * kW + x] = (y + 1) * x + y * (x + 1); - } - } - - LoopNest orig_loopnest({D}, {A, B, C, D}); - { - // First let's try to compute A at axis dy (the outer loop) - LoopNest l(orig_loopnest); - std::vector loops = l.getAllLoopNestsWritingToBuf(D.buf()).at(0); - LoopNest::computeAt(l.getLoopBodyFor(A), loops[0]); - l.prepareForCodegen(); - SimpleIREvaluator cg(l.root_stmt(), {D, W, H}); - StmtPtr s = cg.stmt(); - - // Check the IR we produced - checkIR(s, R"IR( -# CHECK: Allocate(temp); // dtype=int, dims=[1, W] -# CHECK: for (int i = 0; i < H + 1; i++) -# CHECK: for (int i_1 = 0; i_1 < W + 1; i_1++) -# CHECK: A[ -# CHECK: for (int i_2 = 0; i_2 < H + 1; i_2++) -# CHECK: for (int i_3 = 0; i_3 < W + 1; i_3++) -# CHECK: B[ -# CHECK: for (int i_4 = 0; i_4 < H; i_4++) -# CHECK: for (int i_5 = 0; i_5 < W; i_5++) -# CHECK: C[ -# CHECK: for (int i_6 = 0; i_6 < H; i_6++) -# CHECK: for (int i_7 = 0; i_7 < W; i_7++) -# CHECK-NOT: A[)IR"); - - // Now check that the loop still produces the correct result. - std::vector c_data(kW * kH, 0); - cg.call({c_data, kW, kH}); - - assertAllEqual(c_data, c_ref); - } - { - // Now let's try to compute A at axis dx (the inner loop) - LoopNest l(orig_loopnest); - std::vector loops = l.getAllLoopNestsWritingToBuf(D.buf()).at(0); - LoopNest::computeAt(l.getLoopBodyFor(A), loops[1]); - l.prepareForCodegen(); - SimpleIREvaluator cg(l.root_stmt(), {D, W, H}); - StmtPtr s = cg.stmt(); - - // Check the IR we produced - checkIR(s, R"IR( -# CHECK: Allocate(temp); // dtype=int, dims=[1, 1] -# CHECK: for (int i = 0; i < H + 1; i++) -# CHECK: for (int i_1 = 0; i_1 < W + 1; i_1++) -# CHECK: A[ -# CHECK: for (int i_2 = 0; i_2 < H + 1; i_2++) -# CHECK: for (int i_3 = 0; i_3 < W + 1; i_3++) -# CHECK: B[ -# CHECK: for (int i_4 = 0; i_4 < H; i_4++) -# CHECK: for (int i_5 = 0; i_5 < W; i_5++) -# CHECK: C[ -# CHECK: for (int i_6 = 0; i_6 < H; i_6++) -# CHECK: for (int i_7 = 0; i_7 < W; i_7++) -# CHECK-NOT: A[)IR"); - - // Now check that the loop still produces the correct result. 
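-    // Computing A at the innermost loop shrinks the temp to a single
-    // element (dims=[1, 1] above): each (dy, dx) iteration consumes only
-    // A(dy + 1, dx), at the cost of refilling the temp every iteration.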
- std::vector c_data(kW * kH, 0); - cg.call({c_data, kW, kH}); - - assertAllEqual(c_data, c_ref); - } -} - -using Axis = const VarHandle&; - -TEST(LoopNest, Reduce2dComputeAt) { - const int kW = 16, kH = 16; - VarHandle W("W", kInt); - VarHandle H("H", kInt); - - Tensor p = Compute( - "prod", {H + 1, W + 1}, [&](Axis py, Axis px) { return px * py; }); - Tensor c = Reduce( - "cons", - {H, W}, - Sum(), - [&](Axis y, Axis x, Axis r, Axis s) { return p.load(y + r, x + s); }, - {2, 2}); - - std::vector c_ref(kW * kH, 0); - for (int y = 0; y < kH; y++) { - for (int x = 0; x < kW; x++) { - c_ref[y * kW + x] = y * x + (y + 1) * x + y * (x + 1) + (y + 1) * (x + 1); - } - } - LoopNest orig_loopnest({c}, {p, c}); - checkIR(orig_loopnest.root_stmt(), R"IR( -# CHECK: for (int i = 0; i < H + 1; i++) { -# CHECK: for (int i_1 = 0; i_1 < W + 1; i_1++) { -# CHECK: prod[i, i_1] = i_1 * i; -# CHECK: } -# CHECK: } -# CHECK: for (int i_2 = 0; i_2 < H; i_2++) { -# CHECK: for (int i_3 = 0; i_3 < W; i_3++) { -# CHECK: cons[i_2, i_3] = int(0); -# CHECK: for (int i_4 = 0; i_4 < 2; i_4++) { -# CHECK: for (int i_5 = 0; i_5 < 2; i_5++) { -# CHECK: cons[i_2, i_3] = ReduceOp((cons[i_2, i_3]) + (prod[i_2 + i_4, i_3 + i_5]), reduce_args={i_4, i_5}); -# CHECK: } -# CHECK: } -# CHECK: } -# CHECK: } -)IR"); - - { - // First let's try to compute P at axis cy (the outer loop) - LoopNest l(orig_loopnest); - auto loops = l.getAllLoopNestsWritingToBuf(c.buf()).at(0); - LoopNest::computeAt(l.getLoopBodyFor(p), loops[0]); - // FIXME: Calling simplify here breaks the IR: - // MALFORMED INPUT: could not find base node in Load - temp[...] - // l.simplify(); - l.eliminateDeadStores(); - l.prepareForCodegen(); - SimpleIREvaluator cg(l.root_stmt(), {c, W, H}); - checkIR(cg.stmt(), R"IR( -# CHECK: Allocate(temp); // dtype=int, dims=[2, W + 1] -# CHECK: for (int i = 0; i < H; i++) { -# CHECK: for (int idx0 = 0; idx0 < 2; idx0++) { -# CHECK: for (int idx1 = 0; idx1 < W + 1; idx1++) { -# CHECK: temp[(0 + idx0 * (1 * (W + 1))) + idx1 * 1] = (idx0 + i) * (idx1 + 0); -# CHECK: } -# CHECK: } -# CHECK: for (int i_1 = 0; i_1 < W; i_1++) { -# CHECK: cons[(0 + i * (1 * W)) + i_1 * 1] = int(0); -# CHECK: for (int i_2 = 0; i_2 < 2; i_2++) { -# CHECK: for (int i_3 = 0; i_3 < 2; i_3++) { -# CHECK: cons[(0 + i * (1 * W)) + i_1 * 1] = (cons[(0 + i * (1 * W)) + i_1 * 1]) + (temp[(0 + i_2 * (1 * (W + 1))) + (i_1 + i_3) * 1]); -# CHECK: } -# CHECK: } -# CHECK: } -# CHECK: } -# CHECK: Free(temp); -)IR"); - - // Now check that the loop still produces the correct result. 
- std::vector c_data(kW * kH, 0); - cg.call({c_data, kW, kH}); - assertAllEqual(c_data, c_ref); - } - { - // Now let's try to compute P at axis cx (the inner loop) - LoopNest l(orig_loopnest); - std::vector loops = l.getAllLoopNestsWritingToBuf(c.buf()).at(0); - LoopNest::computeAt(l.getLoopBodyFor(p), loops[1]); - l.simplify(); - l.eliminateDeadStores(); - l.prepareForCodegen(); - SimpleIREvaluator cg(l.root_stmt(), {c, W, H}); - checkIR(cg.stmt(), R"IR( -# CHECK: Allocate(temp); // dtype=int, dims=[2, 2] -# CHECK: for (int i = 0; i < H; i++) { -# CHECK: for (int i_1 = 0; i_1 < W; i_1++) { -# CHECK: for (int idx0 = 0; idx0 < 2; idx0++) { -# CHECK: for (int idx1 = 0; idx1 < 2; idx1++) { -# CHECK: temp[(0 + idx0 * (1 * 2)) + idx1 * 1] = (i + idx0) * (i_1 + idx1); -# CHECK: } -# CHECK: } -# CHECK: cons[(0 + i * (1 * W)) + i_1 * 1] = 0; -# CHECK: for (int i_2 = 0; i_2 < 2; i_2++) { -# CHECK: for (int i_3 = 0; i_3 < 2; i_3++) { -# CHECK: cons[(0 + i * (1 * W)) + i_1 * 1] = (cons[(0 + i * (1 * W)) + i_1 * 1]) + (temp[(0 + i_2 * (1 * 2)) + i_3 * 1]); -# CHECK: } -# CHECK: } -# CHECK: } -# CHECK: } -# CHECK: Free(temp); -)IR"); - - // Now check that the loop still produces the correct result. - std::vector c_data(kW * kH, 0); - cg.call({c_data, kW, kH}); - assertAllEqual(c_data, c_ref); - } -} - -TEST(LoopNest, DISABLED_Conv1d_NH) { - // Lots of stuff is broken here. The computeAt swaps the axes for some odd - // reason. Even without that, the index flattener fails due to "dimensions - // mismatch in flatten index". - - int N = 4; - int H = 256; - int R = 3; - int Pad = 1; - BufHandle IP("input", {H}, kFloat); - - Tensor A = Compute("A", {N, H + 2 * Pad}, [&](Axis n, Axis h) { - auto cond = CompareSelect::make(h, Pad, 1, 0, kLT); - cond = CompareSelect::make(h, H + Pad, 1, cond, kGE); - return ifThenElse(cond, 0.f, IP.load(n, h - Pad)); - }); - Tensor B = Reduce( - "B", - {N, H}, - Sum(), - [&](Axis n, Axis h, Axis r) { return A.load(n, h + r); }, - {R}); - LoopNest l({B}); - checkIR(l.root_stmt(), R"IR( -# CHECK: for (int np = 0; np < 4; np++) { -# CHECK: for (int hp = 0; hp < 258; hp++) { -# CHECK: A[np, hp] = IfThenElse(hp>=257 ? 1 : (hp<1 ? 1 : 0), 0.f, input[np, hp - 1]); -# CHECK: } -# CHECK: } -# CHECK: for (int n = 0; n < 4; n++) { -# CHECK: for (int h = 0; h < 256; h++) { -# CHECK: B[n, h] = float(0); -# CHECK: for (int r = 0; r < 3; r++) { -# CHECK: B[n, h] = ReduceOp((B[n, h]) + (A(n, h + r)), reduce_args={r}); -# CHECK: } -# CHECK: } -# CHECK: } -)IR"); - std::vector loops = l.getAllLoopNestsWritingToBuf(B.buf()).at(0); - LoopNest::computeAt(l.getLoopBodyFor(A), loops[0]); - // FIXME: The current IR is totally broken. The body of the inlined loop is: - - // temp[idx0, idx1] = IfThenElse(idx0 + n>=257 ? 1 : (idx0 + n<1 ? 1 : 0), - // 0.f, input[idx1 + 0, (idx0 + n) - 1]); - - // Which seems to mix up the axes. The CHECK below is my best guess at what - // the input "should" look like - - checkIR(l.root_stmt(), R"IR( -# CHECK: for (int n = 0; n < 4; n++) { -# CHECK: for (int idx0 = 0; idx0 < 1; idx0++) { -# CHECK: for (int idx1 = 0; idx1 < 258; idx1++) { - temp[idx0, idx1] = IfThenElse(idx1>=257 ? 1 : (idx1<1 ? 
1 : 0), 0.f, input[n, idx1 - 1]); -# CHECK: } -# CHECK: } -# CHECK: for (int h = 0; h < 256; h++) { -# CHECK: B[n, h] = float(0); -# CHECK: for (int r = 0; r < 3; r++) { -# CHECK: B[n, h] = ReduceOp((B[n, h]) + (temp[0, r + h]), reduce_args={r}); -# CHECK: } -# CHECK: } -# CHECK: } -)IR"); - - l.simplify(); - l.prepareForCodegen(); - StmtPtr s = l.root_stmt(); - - SimpleIREvaluator cg(s, {IP, B}); - // auto At = at::ones({N, H}, at::kFloat); - auto At = at::arange(N * H, at::kFloat).reshape({N, H}); - auto Rt = at::conv1d( - At, at::ones({1, 1, 3}), at::Tensor(), /*stride=*/1, /*padding=*/3); - auto Bt = at::empty_like(Rt); - cg.call({At.data_ptr(), Bt.data_ptr()}); - ASSERT_TRUE(at::allclose(Rt, Bt)); -} - -class LoopOrderHelper : public IRVisitor { - std::stringstream ordering; - - public: - std::string getOrder(StmtPtr s) { - ordering.str(""); - s->accept(this); - return ordering.str(); - } - - void visit(const ForPtr& v) final { - ordering << v->var()->name_hint() << ","; - IRVisitor::visit(v); - } -}; - -TEST(LoopNest, LoopNestReorderAxis1) { - Tensor tensor = - Compute("f", {2, 3}, [](const VarHandle& x, const VarHandle& y) { - return ExprHandle(1.0f) + cast(x) * x + cast(y) * y; - }); - LoopNest l({tensor}); - StmtPtr stmt1 = LoopNest::sanitizeNames(Stmt::clone(l.root_stmt())); - - std::vector stmt1_output(6, 0); - SimpleIREvaluator cg(stmt1, {tensor}); - cg.call({stmt1_output}); - - auto loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - LoopNest::reorderAxis(loops[0], loops[1]); - StmtPtr stmt2 = LoopNest::sanitizeNames(Stmt::clone(l.root_stmt())); - - ASSERT_NE(stmt1, stmt2); - LoopOrderHelper loopOrderHelper; - std::string order1 = loopOrderHelper.getOrder(stmt1); - std::string order2 = loopOrderHelper.getOrder(stmt2); - - ASSERT_EQ(order1, "j,i,"); - ASSERT_EQ(order2, "i,j,"); - - std::vector stmt2_output(6, 0); - SimpleIREvaluator cg2(stmt2, {tensor}); - cg.call({stmt2_output}); - - for (int i = 0; i < 6; ++i) { - ASSERT_EQ(stmt1_output[i], stmt2_output[i]); - } - - // Reorder them back. - loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - LoopNest::reorderAxis(loops[0], loops[1]); - StmtPtr stmt3 = l.root_stmt(); - - std::string order3 = loopOrderHelper.getOrder(stmt3); - ASSERT_EQ(order3, order1); - - std::ostringstream oss1, oss2; - oss1 << *stmt1; - oss2 << *stmt3; - - // Should be identical to the unreordered statement. 
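-  // (For this pair of loops, reorderAxis acts as its own inverse:
-  // applying it twice should round-trip to the original IR, so comparing
-  // the printed statements is sufficient.)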
- ASSERT_EQ(oss1.str(), oss2.str()); -} - -TEST(LoopNest, LoopNestReorderPartialAxes) { - Tensor tensor = Compute( - "f", - {2, 3, 4}, - [](const VarHandle& x, const VarHandle& y, const VarHandle& z) { - return ExprHandle(1.0f) + cast(x) * x + cast(y) * y + - cast(z) * z; - }); - LoopNest l({tensor}); - - LoopOrderHelper loopOrderHelper; - StmtPtr stmt1 = LoopNest::sanitizeNames(Stmt::clone(l.root_stmt())); - ASSERT_EQ(loopOrderHelper.getOrder(stmt1), "i,j,k,"); - - std::vector stmt1_output(24, 0); - SimpleIREvaluator cg(stmt1, {tensor}); - cg.call({stmt1_output}); - - auto loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - LoopNest::reorderAxis(loops[0], loops[1]); - ASSERT_EQ(loopOrderHelper.getOrder(l.root_stmt()), "j,i,k,"); - - StmtPtr stmt2 = Stmt::clone(l.root_stmt()); - - std::vector stmt2_output(24, 0); - SimpleIREvaluator cg2(stmt2, {tensor}); - cg2.call({stmt2_output}); - - for (int i = 0; i < 24; ++i) { - ASSERT_EQ(stmt1_output[i], stmt2_output[i]); - } - - loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - LoopNest::reorderAxis(loops[1], loops[2]); - ASSERT_EQ(loopOrderHelper.getOrder(l.root_stmt()), "j,k,i,"); - - StmtPtr stmt3 = Stmt::clone(l.root_stmt()); - - std::vector stmt3_output(24, 0); - SimpleIREvaluator cg3(stmt3, {tensor}); - cg3.call({stmt3_output}); - - for (int i = 0; i < 24; ++i) { - ASSERT_EQ(stmt1_output[i], stmt3_output[i]); - } -} - -TEST(LoopNest, LoopNestReorderInternalAxis) { - Tensor tensor = Compute( - "f", - {1, 2, 3, 4}, - [](const VarHandle& w, - const VarHandle& x, - const VarHandle& y, - const VarHandle& z) { - return ExprHandle(1.0f) + w + cast(x) * x + cast(y) * y + - cast(z) * z; - }); - LoopNest l({tensor}); - - LoopOrderHelper loopOrderHelper; - StmtPtr stmt1 = LoopNest::sanitizeNames(Stmt::clone(l.root_stmt())); - ASSERT_EQ(loopOrderHelper.getOrder(stmt1), "i,j,k,l,"); - - std::vector stmt1_output(24, 0); - SimpleIREvaluator cg(stmt1, {tensor}); - cg.call({stmt1_output}); - - auto loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - LoopNest::reorderAxis(loops[2], loops[1]); - ASSERT_EQ(loopOrderHelper.getOrder(l.root_stmt()), "i,k,j,l,"); - - StmtPtr stmt2 = l.root_stmt(); - - std::vector stmt2_output(24, 0); - SimpleIREvaluator cg2(stmt2, {tensor}); - cg2.call({stmt2_output}); - - for (int i = 0; i < 24; ++i) { - ASSERT_EQ(stmt1_output[i], stmt2_output[i]); - } -} - -TEST(LoopNest, LoopNestReorderEnclosingAxis) { - Tensor tensor = Compute( - "f", - {1, 2, 3, 4}, - [](const VarHandle& w, - const VarHandle& x, - const VarHandle& y, - const VarHandle& z) { - return ExprHandle(1.0f) + w + cast(x) * x + cast(y) * y + - cast(z) * z; - }); - LoopNest l({tensor}); - - LoopOrderHelper loopOrderHelper; - StmtPtr stmt1 = LoopNest::sanitizeNames(Stmt::clone(l.root_stmt())); - - std::vector stmt1_output(24, 0); - SimpleIREvaluator cg(stmt1, {tensor}); - cg.call({stmt1_output}); - - auto loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - LoopNest::reorderAxis(loops[0], loops[3]); - ASSERT_EQ(loopOrderHelper.getOrder(l.root_stmt()), "l,j,k,i,"); - - StmtPtr stmt2 = l.root_stmt(); - - std::vector stmt2_output(24, 0); - SimpleIREvaluator cg2(stmt2, {tensor}); - cg2.call({stmt2_output}); - - for (int i = 0; i < 24; ++i) { - ASSERT_EQ(stmt1_output[i], stmt2_output[i]); - } -} - -TEST(LoopNest, LoopNestReorderSameAxis) { - Tensor tensor = - Compute("f", {2, 3}, [](const VarHandle& x, const VarHandle& y) { - return ExprHandle(1.0f) + cast(x) * x + cast(y) * y; - }); - LoopNest l({tensor}); - StmtPtr stmt1 = 
Stmt::clone(l.root_stmt()); - - auto loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - LoopNest::reorderAxis(loops[1], loops[1]); - StmtPtr stmt2 = Stmt::clone(l.root_stmt()); - - std::ostringstream oss, oss2; - oss << *stmt1; - oss2 << *stmt2; - ASSERT_EQ(oss.str(), oss2.str()); -} - -TEST(LoopNest, LoopNestReorderExtraStatements) { - /* We're going for a structure like this: - * for i in ... - * Stmt 1 - * for j in ... - * Stmt 2 - * for k in ... - * Stmt 3 - * Stmt 4 - */ - - Tensor tensor = Compute( - "f", - {2, 3, 4}, - [](const VarHandle& x, const VarHandle& y, const VarHandle& z) { - return ExprHandle(1.0f) + cast(x) * x + cast(y) * y + - cast(z) * z; - }); - LoopNest l({tensor}); - - BufHandle extra("res", {6, 3}, kFloat); - - auto loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - - VarHandle i = VarHandle(loops[0]->var()); - - StmtPtr store_1 = Store::make(extra, {i, 0}, 1.f); - StmtPtr store_2 = Store::make(extra, {i, 1}, 2.f); - // stmt 3 is the Function body. - StmtPtr store_3 = Store::make(extra, {i, 2}, 4.f); - - loops[0]->body()->prepend_stmt(store_1); - loops[1]->body()->prepend_stmt(store_2); - loops[1]->body()->append_stmt(store_3); - StmtPtr stmt1 = LoopNest::sanitizeNames(Stmt::clone(l.root_stmt())); - - std::vector extra1(6, 0); - std::vector res1(24, 0); - SimpleIREvaluator cg(stmt1, {tensor, extra}); - cg.call({res1, extra1}); - - /* Then we reorder loop y and z, we want it to look like: - * - * for i in ... - * Stmt 1 - * for j in ... - * Stmt 2 - * for j_1 in ... - * for k in ... - * Stmt 3 - * for j_2 in ... - * Stmt 4 - * - * We need extra loops because we don't have dependency info about stmt 3 - * and 4. - * - */ - - LoopNest::reorderAxis(loops[1], loops[2]); - StmtPtr stmt2 = LoopNest::sanitizeNames(Stmt::clone(l.root_stmt())); - - // Check the IR we produced - checkIR(stmt2, R"IR( -# CHECK: for -# CHECK: res[i, 0] = 1 -# CHECK: for -# CHECK: res[i, 1] = 2 -# CHECK: for -# CHECK: for -# CHECK: f[ -# CHECK: for -# CHECK: res[i, 2] = 4 -)IR"); - - std::vector extra2(6, 0); - std::vector res2(24, 0); - SimpleIREvaluator cg2(stmt2, {tensor, extra}); - cg2.call({res2, extra2}); - - for (int i = 0; i < 24; ++i) { - ASSERT_EQ(res1[i], res2[i]); - } - for (int i = 0; i < 6; ++i) { - ASSERT_EQ(extra1[i], extra2[i]); - } - - /* Now reorder x and the y above stmt 3: - * - * - * for x in ... - * Stmt 1 - * for y in ... - * Stmt 2 - * - * for y in ... - * for z in ... - * for x in ... - * Stmt 3 - * - * for x in ... - * for y in ... 
- * Stmt 4 - * - * - */ - loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); - LoopNest::reorderAxis(loops[0], loops[2]); - StmtPtr stmt3 = LoopNest::sanitizeNames(Stmt::clone(l.root_stmt())); - - // Check the IR we produced - checkIR(stmt3, R"IR( -# CHECK: for -# CHECK: res[i, 0] = 1 -# CHECK: for -# CHECK: res[i, 1] = 2 -# CHECK: for -# CHECK: for -# CHECK: for -# CHECK: f[ -# CHECK: for -# CHECK: for -# CHECK: res[i_2, 2] = 4 -)IR"); - - std::vector extra3(6, 0); - std::vector res3(24, 0); - SimpleIREvaluator cg3(stmt3, {tensor, extra}); - cg3.call({res3, extra3}); - - for (int i = 0; i < 24; ++i) { - ASSERT_EQ(res1[i], res3[i]); - } - for (int i = 0; i < 6; ++i) { - ASSERT_EQ(extra1[i], extra3[i]); - } -} - -void LoopNestReorderTestHelper( - bool prepend, - bool append, - int index1, - int index2) { - Tensor c = Compute( - "5d", {2, 3, 2, 3, 2}, [](const std::vector&) { return -1; }); - LoopNest l({c}); - - BufHandle extra("extra", {5}, kInt); - - auto loops = l.getAllLoopNestsWritingToBuf(c.buf()).at(0); - int j = 0; - for (auto l : loops) { - // Add an increment at each layer of the loop which counts the number of - // times the loop executes. - LoadPtr load = - alloc(extra.node(), std::vector({alloc(j)})); - AddPtr add = alloc(load, alloc(1)); - StmtPtr store = alloc( - extra.node(), std::vector({alloc(j)}), add); - if (prepend) { - l->body()->prepend_stmt(store); - } - if (append) { - l->body()->append_stmt(Stmt::clone(store)); - } - - j++; - } - - StmtPtr stmt1 = Stmt::clone(l.root_stmt()); - - std::vector extra1(5, 0); - std::vector res1(2 * 3 * 2 * 3 * 2, 0); - SimpleIREvaluator cg(stmt1, {c, extra}); - cg.call({res1, extra1}); - - std::vector loopExtents = {2, 3, 2, 3, 2}; - - int expected_loops = 0; - if (prepend) { - expected_loops++; - } - if (append) { - expected_loops++; - } - for (int i = 0; i < 5; ++i) { - expected_loops *= loopExtents[i]; - ASSERT_EQ(extra1[i], expected_loops); - } - - loops = l.getAllLoopNestsWritingToBuf(c.buf()).at(0); - LoopNest::reorderAxis(loops[index1], loops[index2]); - StmtPtr stmt2 = Stmt::clone(l.root_stmt()); - - std::ostringstream oss, oss2; - oss << *stmt1; - oss2 << *stmt2; - ASSERT_NE(oss.str(), oss2.str()); - - std::vector extra2(5, 0); - std::vector res2(2 * 3 * 2 * 3 * 2, 0); - SimpleIREvaluator cg2(stmt2, {c, extra}); - cg2.call({res2, extra2}); - - expected_loops = 0; - if (prepend) { - expected_loops++; - } - if (append) { - expected_loops++; - } - - for (int i = 0; i < 5; ++i) { - expected_loops *= loopExtents[i]; - ASSERT_EQ(extra2[i], expected_loops); - } - - for (int i = 0; i < 2 * 3 * 2 * 3 * 2; ++i) { - ASSERT_EQ(res2[i], res1[i]); - } -} - -TEST(LoopNest, LoopNestReorderLongStringOfPreOrphans) { - for (int i = 0; i < 5; ++i) { - for (int j = 0; j < 5; ++j) { - // skip noops, since we check the loop isn't the same after reordering. - if (i != j) { - LoopNestReorderTestHelper(true, false, i, j); - } - } - } -} - -TEST(LoopNest, LoopNestReorderLongStringOfPostOrphans) { - for (int i = 0; i < 5; ++i) { - for (int j = 0; j < 5; ++j) { - // skip noops, since we check the loop isn't the same after reordering. - if (i != j) { - LoopNestReorderTestHelper(false, true, i, j); - } - } - } -} - -TEST(LoopNest, LoopNestReorderLongStringFull) { - for (int i = 0; i < 5; ++i) { - for (int j = 0; j < 5; ++j) { - // skip noops, since we check the loop isn't the same after reordering. 
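// (The helper invocation continues just below.) Why the helper's
// expected_loops arithmetic holds, checked standalone in plain C++: a
// statement in the body of loop level d executes once per combined
// iteration of all enclosing loops, i.e. extents[0] * ... * extents[d]
// times (scaled by 2 when stores are both prepended and appended).
#include <cassert>
int main() {
  const int extents[5] = {2, 3, 2, 3, 2};
  long counts[5] = {0, 0, 0, 0, 0};    // one counter per loop level
  for (int a = 0; a < extents[0]; a++) { counts[0]++;
    for (int b = 0; b < extents[1]; b++) { counts[1]++;
      for (int c = 0; c < extents[2]; c++) { counts[2]++;
        for (int d = 0; d < extents[3]; d++) { counts[3]++;
          for (int e = 0; e < extents[4]; e++) { counts[4]++; }
        }
      }
    }
  }
  long expected = 1;
  for (int d = 0; d < 5; d++) {
    expected *= extents[d];
    assert(counts[d] == expected);     // mirrors the ASSERT_EQ on extra1[i]
  }
  return 0;
}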
- if (i != j) { - LoopNestReorderTestHelper(true, true, i, j); - } - } - } -} - -TEST(LoopNest, LoopNestReorderInternalLoopNest) { - const int M = 4; - const int N = 5; - const int K = 6; - BufHandle a_buf("a", {M, N}, kFloat); - BufHandle b_buf("b", {N, K}, kFloat); - BufHandle c_buf("c", {M, N}, kFloat); - BufHandle d_buf("d", {M, K}, kFloat); - - Tensor x = Compute( - "x", - {M, N, K}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return a_buf.load(m, n) * b_buf.load(n, k); - }); - Tensor y = Compute( - "y", - {M, N, K}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return c_buf.load(m, n) * d_buf.load(m, k) + x.load(m, n, k); - }); - Tensor z = Compute( - "z", - {M, N, K}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return x.load(m, n, k) + y.load(m, n, k); - }); - - LoopNest l({z}, {x, y, z}); - ForPtr a = l.getAllLoopNestsWritingToBuf(y.buf())[0][2]; - ForPtr b = l.getAllLoopNestsWritingToBuf(y.buf())[0][0]; - LoopNest::reorderAxis(a, b); - - l.prepareForCodegen(); - StmtPtr stmt = IRSimplifier::simplify(l.root_stmt()); - - // Check the IR we produced has the 3 nests in the right order, but k and m - // swapped in the middle. - checkIR(stmt, R"IR( -# CHECK: < 4 -# CHECK: < 5 -# CHECK: < 6 -# CHECK: < 6 -# CHECK: < 5 -# CHECK: < 4 -# CHECK: < 4 -# CHECK: < 5 -# CHECK: < 6)IR"); - - { - PaddedBuffer a_v(M, N); - PaddedBuffer b_v(N, K); - PaddedBuffer c_v(M, N); - PaddedBuffer d_v(M, K); - - for (int i = 0; i < M; i++) { - for (int j = 0; j < N; j++) { - a_v(i, j) = i * i; - } - } - for (int i = 0; i < N; i++) { - for (int j = 0; j < K; j++) { - b_v(i, j) = j * j; - } - } - for (int i = 0; i < M; i++) { - for (int j = 0; j < N; j++) { - c_v(i, j) = i + j; - } - } - for (int i = 0; i < M; i++) { - for (int j = 0; j < K; j++) { - d_v(i, j) = i * j; - } - } - - PaddedBuffer z_v(M, N, K); - PaddedBuffer z_ref(M, N, K); - for (int m = 0; m < M; m++) { - for (int n = 0; n < N; n++) { - for (int k = 0; k < K; k++) { - z_ref(m, n, k) = a_v(m, n) * b_v(n, k) * 2 + c_v(m, n) * d_v(m, k); - } - } - } - - SimpleIREvaluator eval(stmt, {a_buf, b_buf, c_buf, d_buf, z}); - eval(a_v, b_v, c_v, d_v, z_v); - ExpectAllNear(z_v, z_ref, 1e-5); - } -} - -TEST(LoopNest, OuterLoopVectorization) { - Tensor tensor = - Compute("f", {8, 8}, [](const VarHandle& x, const VarHandle& y) { - return ExprHandle(1.0f) + cast(x) * x + cast(y) * y; - }); - LoopNest l({tensor}); - - ASSERT_TRUE( - LoopNest::vectorize(l.getAllLoopNestsWritingToBuf(tensor.buf())[0][0])); - - StmtPtr root_stmt = l.root_stmt(); - BlockPtr outer_block = to(root_stmt); - ASSERT_NE(outer_block, nullptr); - while (BlockPtr inner_block = to(outer_block->front())) { - outer_block = inner_block; - } - - // Verify that we have only a single loop level remaining after - // vectorization. 
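// A rough standalone picture of what is being asserted below (plain C++,
// not the vectorizer itself): vectorizing the outer 8-wide axis folds that
// axis into the statement -- conceptually one 8-lane vector operation per
// remaining iteration -- leaving a single scalar loop level.
#include <cassert>
int main() {
  float scalar[8][8], vec[8][8];
  for (int x = 0; x < 8; x++)              // original two-level nest
    for (int y = 0; y < 8; y++)
      scalar[x][y] = 1.0f + float(x) * x + float(y) * y;
  for (int y = 0; y < 8; y++)              // the single remaining loop
    for (int lane = 0; lane < 8; lane++)   // stands in for one vector op
      vec[lane][y] = 1.0f + float(lane) * lane + float(y) * y;
  for (int x = 0; x < 8; x++)
    for (int y = 0; y < 8; y++)
      assert(scalar[x][y] == vec[x][y]);
  return 0;
}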
- ASSERT_EQ(outer_block->nstmts(), 1); - ForPtr for_loop = to(outer_block->front()); - ASSERT_NE(for_loop, nullptr); - BlockPtr for_body = for_loop->body(); - ASSERT_EQ(for_body->nstmts(), 1); - ASSERT_EQ(to(for_body->front()), nullptr); -} - -TEST(LoopNest, VectorizeLoopNotNormalized) { - // Input IR: - // for (int i = 0; i < 10; i++) { - // for (int j = 1; j < 5; j++) { - // A[i,j] = i * j; - // } - // } - BufHandle a_buf("A", {10, 5}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - auto for_body = Block::make({Store::make(a_buf, {i, j}, i * j)}); - auto inner_for = For::make(j, 1, 5, for_body); - auto outer_for = For::make(i, 0, 10, inner_for); - auto block = Block::make({outer_for}); - LoopNest l(block, {a_buf.node()}); - - ASSERT_TRUE(LoopNest::vectorize(inner_for)); - ASSERT_EQ(outer_for->body()->nstmts(), 1); - ASSERT_EQ(to(outer_for->body()->front()), nullptr); -} - -namespace { - -std::string constantUpperBoundLoopIR(int upper_bound_val) { - ExprHandle upper_bound(upper_bound_val); - Tensor A = - Compute("A", {upper_bound}, [&](const VarHandle& x) { return x * 2; }); - LoopNest l({A}); - std::vector loops = l.getAllLoopNestsWritingToBuf(A.buf())[0]; - StmtPtr unrolled = nullptr; - LoopNest::fullUnroll(loops[0], &unrolled); - std::ostringstream oss; - oss << *unrolled; - return oss.str(); -} - -} // namespace - -TEST(LoopNest, Unroll) { - const std::string actual = constantUpperBoundLoopIR(3); - const std::string& verification_pattern = - R"IR( -# CHECK: A[0] = 0; -# CHECK: A[1] = 2; -# CHECK: A[2] = 4)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, actual); -} - -TEST(LoopNest, UnrollOuter) { - ExprHandle outer_bound(3); - ExprHandle inner_bound(4); - Tensor A = Compute( - "A", - {outer_bound, inner_bound}, - [&](const VarHandle& x, const VarHandle& y) { return x + y; }); - LoopNest l({A}); - std::vector loops = l.getAllLoopNestsWritingToBuf(A.buf())[0]; - StmtPtr unrolled = nullptr; - LoopNest::fullUnroll(loops[0], &unrolled); - checkIR(unrolled, R"IR( -# CHECK: for (int i = 0; i < 4; i++) { -# CHECK: A[0, i] = i; -# CHECK: } -# CHECK: for (int i = 0; i < 4; i++) { -# CHECK: A[1, i] = i + 1; -# CHECK: } -# CHECK: for (int i = 0; i < 4; i++) { -# CHECK: A[2, i] = i + 2; -# CHECK: })IR"); -} - -TEST(LoopNest, UnrollInner) { - ExprHandle outer_bound(3); - ExprHandle inner_bound(4); - Tensor A = Compute( - "A", - {outer_bound, inner_bound}, - [&](const VarHandle& x, const VarHandle& y) { return x + y; }); - LoopNest l({A}); - std::vector loops = l.getAllLoopNestsWritingToBuf(A.buf())[0]; - StmtPtr unrolled = nullptr; - LoopNest::fullUnroll( - static_to(loops[0]->body()->stmts().front()), &unrolled); - checkIR(loops[0], R"IR( -# CHECK: for (int i = 0; i < 3; i++) { -# CHECK: A[i, 0] = i; -# CHECK: A[i, 1] = i + 1; -# CHECK: A[i, 2] = i + 2; -# CHECK: A[i, 3] = i + 3; -# CHECK: })IR"); -} - -TEST(LoopNest, UnrollMultipleStatements) { - const int kTotalSize = 3; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); - - VarHandle x("x", kInt); - auto f = For::make( - x, - 0, - kTotalSize, - Block::make( - {Store::make(a_buf, {x}, x * 2), - Store::make(b_buf, {x}, Load::make(a_buf, {x}))})); - auto parent_block = Block::make({f}); - StmtPtr unrolled = nullptr; - LoopNest::fullUnroll(f, &unrolled); - checkIR(unrolled, R"IR( -# CHECK: A[0] = 0; -# CHECK: B[0] = A[0]; -# CHECK: A[1] = 2; -# CHECK: B[1] = A[1]; -# CHECK: A[2] = 4 -# CHECK: B[2] = A[2];)IR"); -} - -TEST(LoopNest, 
UnrollNonLiteralConstantBounds) { - // Input IR: - // for (int i = 2 - 1; i < 12 / 3; i++) { - // for (int j = 0; j < 4; j++) { - // A[i,j] = i * j; - // } - // } - BufHandle a_buf("A", {3, 4}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - auto for_body = Block::make({Store::make(a_buf, {i, j}, i * j)}); - auto inner_for = For::make(j, 0, 4, for_body); - auto outer_for = For::make( - i, - IntImm::make(2) - IntImm::make(1), - IntImm::make(12) / IntImm::make(3), - inner_for); - // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) - auto b = Block::make({outer_for}); - - std::vector loops = {outer_for, inner_for}; - StmtPtr unrolled = nullptr; - LoopNest::fullUnroll(loops[0], &unrolled); - checkIR(unrolled, R"IR( -# CHECK: for (int j = 0; j < 4; j++) { -# CHECK: A[1, j] = j; -# CHECK: } -# CHECK: for (int j = 0; j < 4; j++) { -# CHECK: A[2, j] = 2 * j; -# CHECK: } -# CHECK: for (int j = 0; j < 4; j++) { -# CHECK: A[3, j] = 3 * j; -# CHECK: })IR"); -} - -TEST(LoopNest, UnrollNonConstantBounds) { - // Input IR: - // for (int i = 0; i < M; i++) { - // for (int j = 0; j < N; j++) { - // A[i, j] = i * j; - // } - // } - VarHandle M("M", kInt); - VarHandle N("N", kInt); - BufHandle a_buf("A", {M, N}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - auto for_body = Block::make({Store::make(a_buf, {i, j}, i * j)}); - auto inner_for = For::make(j, 0, N, for_body); - auto outer_for = For::make(i, 0, M, inner_for); - auto block = Block::make({outer_for}); - LoopNest l(block, {a_buf.node()}); - - LoopNest::unroll(inner_for, 8); - l.simplify(); - checkIR(l.root_stmt(), R"IR( - # CHECK: for (int i = 0; i < M; i++) { - # CHECK: for (int j_outer = 0; j_outer < N / 8; j_outer++) { - # CHECK: A[i, 8 * j_outer] = - # CHECK: A[i, 8 * j_outer + 1] = - # CHECK: A[i, 2 * (4 * j_outer + 1)] = - # CHECK: A[i, 8 * j_outer + 3] = - # CHECK: A[i, 4 * (2 * j_outer + 1)] = - # CHECK: A[i, 8 * j_outer + 5] = - # CHECK: A[i, 8 * j_outer + 6] = - # CHECK: A[i, 8 * j_outer + 7] = - # CHECK: } - # CHECK: for (int j_tail = 0; j_tail < N % 8; j_tail++) { - # CHECK: A[i, 8 * (N / 8) + j_tail] = - # CHECK: } - # CHECK: } - )IR"); -} - -TEST(LoopNest, UnrollByFactorsLessThan2) { - // Input IR: - // for (int i = 0; i < M; i++) { - // for (int j = 0; j < N; j++) { - // A[i, j] = i * j; - // } - // } - VarHandle M("M", kInt); - VarHandle N("N", kInt); - BufHandle a_buf("A", {M, N}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - auto for_body = Block::make({Store::make(a_buf, {i, j}, i * j)}); - auto inner_for = For::make(j, 0, N, for_body); - auto outer_for = For::make(i, 0, M, inner_for); - auto block = Block::make({outer_for}); - LoopNest l(block, {a_buf.node()}); - - // Unrolling by factor = 1 should do nothing. - LoopNest::unroll(inner_for, 1); - checkIR(l.root_stmt(), R"IR( - # CHECK: for (int i = 0; i < M; i++) { - # CHECK: for (int j = 0; j < N; j++) { - # CHECK: A[i, j] = - # CHECK: } - # CHECK: } - )IR"); - - // Unrolling by factor = 0 should do nothing. - LoopNest::unroll(inner_for, 0); - checkIR(l.root_stmt(), R"IR( - # CHECK: for (int i = 0; i < M; i++) { - # CHECK: for (int j = 0; j < N; j++) { - # CHECK: A[i, j] = - # CHECK: } - # CHECK: } - )IR"); - - // Unrolling by negative factor should do nothing. 
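// (The negative-factor call follows below.) For contrast, the factor-8
// unroll exercised above preserves exactly this decomposition, checked
// standalone in plain C++: N iterations == floor(N/8) chunks of 8 plus an
// N % 8 tail.
#include <cassert>
int main() {
  const int N = 19;
  int a[N], b[N];
  for (int j = 0; j < N; j++) a[j] = j * j;  // original loop
  for (int jo = 0; jo < N / 8; jo++)         // unrolled main loop
    for (int ji = 0; ji < 8; ji++) {         // (the IR writes these out 8x)
      int j = 8 * jo + ji;
      b[j] = j * j;
    }
  for (int jt = 0; jt < N % 8; jt++) {       // tail loop
    int j = 8 * (N / 8) + jt;
    b[j] = j * j;
  }
  for (int j = 0; j < N; j++) assert(a[j] == b[j]);
  return 0;
}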
- LoopNest::unroll(inner_for, -2); - checkIR(l.root_stmt(), R"IR( - # CHECK: for (int i = 0; i < M; i++) { - # CHECK: for (int j = 0; j < N; j++) { - # CHECK: A[i, j] = - # CHECK: } - # CHECK: } - )IR"); -} - -TEST(LoopNest, UnrollByFactorEqualToIters) { - // Input IR: - // for (int i = 0; i < 5; i++) { - // A[i] = i * i; - // } - BufHandle a_buf("A", {5}, kInt); - VarHandle i("i", kInt); - auto for_body = Block::make({Store::make(a_buf, {i}, i * i)}); - auto for_loop = For::make(i, 0, 5, for_body); - auto block = Block::make({for_loop}); - LoopNest l(block, {a_buf.node()}); - - LoopNest::unroll(for_loop, 5); - checkIR(l.root_stmt(), R"IR( - # CHECK: for (int i_outer = 0; i_outer < (5 - 0) / 5; i_outer++) - # CHECK: A[5 * i_outer] - # CHECK: A[5 * i_outer + 1] - # CHECK: A[5 * i_outer + 2] - # CHECK: A[5 * i_outer + 3] - # CHECK: A[5 * i_outer + 4] - )IR"); -} - -TEST(LoopNest, UnrollEmpty) { - const std::string actual = constantUpperBoundLoopIR(0); - const std::string& verification_pattern = R"IR( -# CHECK-NOT: A[ - )IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, actual); -} - -TEST(LoopNest, NoUnroll) { - VarHandle upper_bound("N", kInt); - Tensor A = - Compute("A", {upper_bound}, [&](const VarHandle& x) { return x * 2; }); - LoopNest l({A}); - std::vector loops = l.getAllLoopNestsWritingToBuf(A.buf())[0]; - StmtPtr unrolled = nullptr; - ASSERT_THROWS_WITH( - LoopNest::fullUnroll(loops[0], &unrolled), "non-constant loop"); -} - -TEST(LoopNest, UnrollWithLet) { - const int kTotalSize = 3; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); - - VarHandle e("e", kInt); - VarHandle x("x", kInt); - auto f = For::make( - x, - 0, - kTotalSize, - Block::make( - {Let::make(e, 7), - Store::make(a_buf, {x}, e), - Store::make(b_buf, {x}, e + 1)})); - auto parent_block = Block::make({f}); - StmtPtr unrolled = nullptr; - LoopNest::fullUnroll(f, &unrolled); - std::ostringstream oss; - oss << *unrolled; - const std::string& verification_pattern = - R"IR( -# CHECK: int e = 7; -# CHECK: A[0] = e; -# CHECK: B[0] = e + 1; -# CHECK: A[1] = e; -# CHECK: B[1] = e + 1; -# CHECK: A[2] = e; -# CHECK: B[2] = e + 1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - std::vector a_v(kTotalSize, 0); - std::vector b_v(kTotalSize, 0); - SimpleIREvaluator eval(unrolled, {a_buf, b_buf}); - eval(a_v, b_v); - for (int i = 0; i < kTotalSize; ++i) { - ASSERT_EQ(a_v[i], 7); - ASSERT_EQ(b_v[i], 8); - } -} - -TEST(LoopNest, IsNormalized) { - // Input IR: - // for (int i = 50; i < 100; i++) { - // A[i] = B[i]; - // } - BufHandle a_buf("A", {ExprHandle(100)}, kInt); - BufHandle b_buf("B", {ExprHandle(100)}, kInt); - VarHandle i("i", kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto for_stmt = - For::make(i, 50, 100, Store::make(a_buf, {i}, Load::make(b_buf, {i}))); - Block::make({for_stmt}); - ASSERT_FALSE(LoopNest::isNormalized(for_stmt)); - - for_stmt->set_start(alloc(0)); - ASSERT_TRUE(LoopNest::isNormalized(for_stmt)); - - VarHandle N("N", kInt); - for_stmt->set_start(N.node()); - ASSERT_FALSE(LoopNest::isNormalized(for_stmt)); -} - -TEST(LoopNest, NormalizeStartPositive) { - // Input IR: - // for (int x = 50; x < 100; x++) { - // A[x] = B[x]; - // B[x] = x * 2; - // } - const int kTotalSize = 50; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); - VarHandle x("x", kInt); - auto for_body = Block::make( - {Store::make(a_buf, {x}, 
Load::make(kInt, b_buf, {x})), - Store::make(b_buf, {x}, x * 2)}); - auto for_stmt = For::make(x, 50, 100, for_body); - Block::make({for_stmt}); - - LoopNest::normalize(for_stmt); - - auto result = IRSimplifier::simplify(for_stmt); - std::ostringstream oss; - oss << *result; - const std::string& expected_ir = - R"IR( - # CHECK: for (int x = 0; x < 50; x++) { - # CHECK: A[x + 50] = B[x + 50]; - # CHECK: B[x + 50] = 2 * (x + 50); - )IR"; - torch::jit::testing::FileCheck().run(expected_ir, oss.str()); -} - -TEST(LoopNest, NormalizeStartNegative) { - // Input IR: - // for (int x = -50; x < 100; x++) { - // A[x + 50] = B[x + 50]; - // B[x + 50] = x * 2; - // } - const int kTotalSize = 150; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); - VarHandle x("x", kInt); - auto for_body = Block::make( - {Store::make(a_buf, {x + 50}, Load::make(kInt, b_buf, {x + 50})), - Store::make(b_buf, {x + 50}, x * 2)}); - auto for_stmt = For::make(x, -50, 100, for_body); - Block::make({for_stmt}); - - LoopNest::normalize(for_stmt); - - auto result = IRSimplifier::simplify(for_stmt); - std::ostringstream oss; - oss << *result; - const std::string& expected_ir = - R"IR( - # CHECK: for (int x = 0; x < 150; x++) { - # CHECK: A[x] = B[x]; - # CHECK: B[x] = 2 * (x - 50); - )IR"; - torch::jit::testing::FileCheck().run(expected_ir, oss.str()); -} - -TEST(LoopNest, NormalizeStartZero) { - // Input IR: - // for (int x = 0; x < 100; x++) { - // A[x] = B[x]; - // B[x] = x * 2; - // } - // Should not be modified. - - const int kTotalSize = 100; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); - VarHandle x("x", kInt); - auto for_body = Block::make( - {Store::make(a_buf, {x}, Load::make(kInt, b_buf, {x})), - Store::make(b_buf, {x}, x * 2)}); - auto for_stmt = For::make(x, 0, 100, for_body); - Block::make({for_stmt}); - - LoopNest::normalize(for_stmt); - - auto result = IRSimplifier::simplify(for_stmt); - std::ostringstream oss; - oss << *result; - const std::string& expected_ir = - R"IR( - # CHECK: for (int x = 0; x < 100; x++) { - # CHECK: A[x] = B[x]; - # CHECK: B[x] = 2 * x; - )IR"; - torch::jit::testing::FileCheck().run(expected_ir, oss.str()); -} - -TEST(LoopNest, NormalizeStartVariable) { - // Input IR: - // for (int x = y; x < 100; x++) { - // A[x] = B[x]; - // B[x] = x * 2; - // } - - const int kTotalSize = 100; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); - BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - auto for_body = Block::make( - {Store::make(a_buf, {x}, Load::make(kInt, b_buf, {x})), - Store::make(b_buf, {x}, x * 2)}); - auto for_stmt = For::make(x, y, 100, for_body); - auto parent_block = Block::make({for_stmt}); - - LoopNest::normalize(for_stmt); - - auto result = IRSimplifier::simplify(for_stmt); - std::ostringstream oss; - oss << *result; - const std::string& expected_ir = - R"IR( - # CHECK: for (int x = 0; x < 100 - y; x++) { - # CHECK: A[x + y] = B[x + y]; - # CHECK: B[x + y] = 2 * (x + y); - )IR"; - torch::jit::testing::FileCheck().run(expected_ir, oss.str()); -} - -TEST(LoopNest, NormalizeOnNestedOuterLoop) { - // Input IR: - // for (int x = 50; x < 100; x++) { - // for (int y = 10; y < 100; y++) { - // A[x] = A[x] + B[y] + y * 2; - // } - // } - - BufHandle a_buf("A", {ExprHandle(50)}, kInt); - BufHandle b_buf("B", {ExprHandle(100)}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - auto 
inner_for_body = Store::make( - a_buf, {x}, Load::make(a_buf, {x}) + Load::make(b_buf, {y}) + y * 2); - auto inner_for = For::make(y, 10, 100, inner_for_body); - auto for_stmt = For::make(x, 50, 100, inner_for); - Block::make({for_stmt}); - - LoopNest::normalize(for_stmt); - - auto result = IRSimplifier::simplify(for_stmt); - std::ostringstream oss; - oss << *result; - const std::string& expected_ir = - R"IR( - # CHECK: for (int x = 0; x < 50; x++) { - # CHECK: for (int y = 10; y < 100; y++) { - # CHECK: A[x + 50] = ((A[x + 50]) + (B[y])) + 2 * y; - )IR"; - torch::jit::testing::FileCheck().run(expected_ir, oss.str()); -} - -TEST(LoopNest, NormalizeOnNestedInnerLoop) { - // Input IR: - // for (int x = 50; x < 100; x++) { - // for (int y = 10; y < 100; y++) { - // A[x] = A[x] + B[y] + y * 2; - // } - // } - - BufHandle a_buf("A", {ExprHandle(50)}, kInt); - BufHandle b_buf("B", {ExprHandle(100)}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - auto inner_for_body = Store::make( - a_buf, {x}, Load::make(a_buf, {x}) + Load::make(b_buf, {y}) + y * 2); - auto inner_for = For::make(y, 10, 100, inner_for_body); - auto for_stmt = For::make(x, 50, 100, inner_for); - Block::make({for_stmt}); - - LoopNest::normalize(inner_for); - - auto result = IRSimplifier::simplify(for_stmt); - std::ostringstream oss; - oss << *result; - const std::string& expected_ir = - R"IR( - # CHECK: for (int x = 50; x < 100; x++) { - # CHECK: for (int y = 0; y < 90; y++) { - # CHECK: A[x] = (((A[x]) + (B[y + 10])) + 2 * y) + 20; - )IR"; - torch::jit::testing::FileCheck().run(expected_ir, oss.str()); -} - -TEST(LoopNest, NormalizeAndSplitWithTail) { - // Create a dummy tensor to construct LoopNest. - ExprHandle n(100); - BufHandle a("a", {n}, kFloat); - Tensor b = Compute("b", {n}, [&](const VarHandle& i) { return a.load(i); }); - LoopNest l({b}); - - // Input IR: - // for (int x = 5; x < 10; x++) { - // A[x] = x * 2; - // } - const int kTotalSize = 5; - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); - VarHandle x("x", kInt); - auto for_stmt = For::make(x, 5, 10, Store::make(a_buf, {x}, x * 2)); - auto parent_block = Block::make({for_stmt}); - - LoopNest::normalize(for_stmt); - - ForPtr x_inner; - ForPtr x_tail; - LoopNest::splitWithTail(for_stmt, 10, &x_inner, &x_tail); - - auto x_outer_result = IRSimplifier::simplify(for_stmt); - std::ostringstream oss_outer; - oss_outer << *x_outer_result; - const std::string& expected_outer_ir = - R"IR( - # CHECK: { - # CHECK: } - )IR"; - torch::jit::testing::FileCheck().run(expected_outer_ir, oss_outer.str()); - - auto x_tail_result = IRSimplifier::simplify(x_tail); - std::ostringstream oss_tail; - oss_tail << *x_tail_result; - const std::string& expected_tail_ir = - R"IR( - # CHECK: for (int x_tail = 0; x_tail < 5; x_tail++) { - # CHECK: A[x_tail + 5] = 2 * (x_tail + 5); - )IR"; - torch::jit::testing::FileCheck().run(expected_tail_ir, oss_tail.str()); -} - -TEST(LoopNest, NotNormalizeAndSplitWithTail) { - // Create a dummy tensor to construct LoopNest. 
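// (The dummy tensor is created just below.) The previous test's numbers,
// checked standalone in plain C++: normalizing for (x = 5; x < 10) yields a
// 5-iteration loop starting at 0, and a factor-10 splitWithTail of it
// leaves an empty main loop plus a 5-iteration tail writing A[x_tail + 5].
#include <cassert>
int main() {
  int a1[10] = {0}, a2[10] = {0};
  for (int x = 5; x < 10; x++) a1[x] = x * 2;  // original loop
  const int trip = 5;                          // normalized trip count
  for (int xo = 0; xo < trip / 10; xo++) {}    // main loop: runs 0 times
  for (int xt = 0; xt < trip % 10; xt++)       // tail loop: runs 5 times
    a2[xt + 5] = (xt + 5) * 2;
  for (int x = 0; x < 10; x++) assert(a1[x] == a2[x]);
  return 0;
}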
- ExprHandle n(100); - BufHandle a("a", {n}, kFloat); - Tensor b = Compute("b", {n}, [&](const VarHandle& i) { return a.load(i); }); - LoopNest l({b}); - - // Input IR: - // for (int x = 5; x < 15; x++) { - // A[x] = x * 2; - // } - const int kTotalSize = 10; - BufHandle a_buf("A", {kTotalSize}, kInt); - VarHandle x("x", kInt); - auto for_stmt = For::make(x, 5, 15, Store::make(a_buf, {x}, x * 2)); - auto parent_block = Block::make({for_stmt}); - - ForPtr x_inner; - ForPtr x_tail; - LoopNest::splitWithTail(for_stmt, 8, &x_inner, &x_tail); - - auto x_outer_result = IRSimplifier::simplify(for_stmt); - std::ostringstream oss_outer; - oss_outer << *x_outer_result; - const std::string& expected_outer_ir = - R"IR( - # CHECK: { - # CHECK: } - )IR"; - torch::jit::testing::FileCheck().run(expected_outer_ir, oss_outer.str()); - - auto x_tail_result = IRSimplifier::simplify(x_tail); - std::ostringstream oss_tail; - oss_tail << *x_tail_result; - const std::string& expected_tail_ir = - R"IR( - # CHECK: for (int x_tail = 0; x_tail < 2; x_tail++) { - # CHECK: A[x_tail + 13] = 2 * (x_tail + 13); - )IR"; - torch::jit::testing::FileCheck().run(expected_tail_ir, oss_tail.str()); -} - -TEST(LoopNest, FlattenSimpleLoopNest2D) { - // Input IR: - // for (int i = 0; i < 10; i++) { - // for (int j = 0; j < 5; j++) { - // A[i,j] = i * j; - // } - // } - BufHandle a_buf("A", {10, 5}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - auto for_body = Block::make({Store::make(a_buf, {i, j}, i * j)}); - auto inner_for = For::make(j, 0, 5, for_body); - auto outer_for = For::make(i, 0, 10, inner_for); - auto parent_block = Block::make({outer_for}); - - std::vector loops = {outer_for, inner_for}; - ForPtr flattened = nullptr; - ASSERT_TRUE(LoopNest::flatten(loops, &flattened)); - ASSERT_EQ(flattened, loops.front()); - - auto result = IRSimplifier::simplify(flattened); - std::ostringstream oss; - oss << *result; - const std::string& expected_ir = - R"IR( - # CHECK: for (int i_flat = 0; i_flat < 50; i_flat++) { - # CHECK: A[i_flat / 5, i_flat % 5] = - )IR"; - torch::jit::testing::FileCheck().run(expected_ir, oss.str()); - - { - SimpleIREvaluator eval1(loops[0], {a_buf}); - PaddedBuffer inp1(10, 5); - eval1(inp1); - SimpleIREvaluator eval2(flattened, {a_buf}); - PaddedBuffer inp2(10, 5); - eval2(inp2); - ExpectAllNear(inp1, inp2, 1e-5); - } -} - -TEST(LoopNest, FlattenSimpleLoopNest3D) { - // Input IR: - // for (int i = 0; i < 10; i++) { - // for (int j = 0; j < 5; j++) { - // for (int k = 0; k < 7; k++) { - // A[i,j,k] = i + j * k; - // } - // } - // } - BufHandle a_buf("A", {10, 5, 7}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto for_body = Block::make({Store::make(a_buf, {i, j, k}, i + j * k)}); - auto for1 = For::make(k, 0, 7, for_body); - auto for2 = For::make(j, 0, 5, for1); - auto for3 = For::make(i, 0, 10, for2); - auto parent_block = Block::make({for3}); - - std::vector loops = {for3, for2, for1}; - ForPtr flattened = nullptr; - ASSERT_TRUE(LoopNest::flatten(loops, &flattened)); - ASSERT_EQ(flattened, loops.front()); - - auto result = IRSimplifier::simplify(flattened); - std::ostringstream oss; - oss << *result; - const std::string& expected_ir = - R"IR( - # CHECK: for (int i_flat = 0; i_flat < 350; i_flat++) { - # CHECK: A[i_flat / 35, (i_flat / 7) % 5, i_flat % 7] = - )IR"; - torch::jit::testing::FileCheck().run(expected_ir, oss.str()); - - { - SimpleIREvaluator eval1(loops[0], {a_buf}); - PaddedBuffer inp1(10, 5, 7); - eval1(inp1); - SimpleIREvaluator 
eval2(flattened, {a_buf}); - PaddedBuffer inp2(10, 5, 7); - eval2(inp2); - ExpectAllNear(inp1, inp2, 1e-5); - } -} - -TEST(LoopNest, FlattenLoopNestAfterNormalize) { - // Input IR: - // for (int i = 2; i < 10; i++) { - // for (int j = 3; j < 15; j++) { - // A[i - 2,j - 3] = i * j; - // } - // } - BufHandle a_buf("A", {8, 12}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - auto for_body = Block::make({Store::make(a_buf, {i - 2, j - 3}, i * j)}); - auto inner_for = For::make(j, 3, 15, for_body); - auto outer_for = For::make(i, 2, 10, inner_for); - auto parent_block = Block::make({outer_for}); - - std::vector loops = {outer_for, inner_for}; - ForPtr flattened = nullptr; - ASSERT_TRUE(LoopNest::flatten(loops, &flattened)); - ASSERT_EQ(flattened, loops.front()); - - auto result = IRSimplifier::simplify(flattened); - std::ostringstream oss; - oss << *result; - const std::string& expected_ir = - R"IR( - # CHECK: for (int i_flat = 0; i_flat < 96; i_flat++) { - # CHECK: A[i_flat / 12, i_flat % 12] = - )IR"; - torch::jit::testing::FileCheck().run(expected_ir, oss.str()); - - { - SimpleIREvaluator eval1(loops[0], {a_buf}); - PaddedBuffer inp1(8, 12); - eval1(inp1); - SimpleIREvaluator eval2(flattened, {a_buf}); - PaddedBuffer inp2(8, 12); - eval2(inp2); - ExpectAllNear(inp1, inp2, 1e-5); - } -} - -TEST(LoopNest, FlattenLoopNestWithNonLiteralConstantBounds) { - // Input IR: - // for (int i = 0; i < 15-5; i++) { - // for (int j = 0; j < 20/4; j++) { - // A[i,j] = i * j; - // } - // } - BufHandle a_buf("A", {10, 5}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - auto for_body = Block::make({Store::make(a_buf, {i, j}, i * j)}); - auto inner_for = - For::make(j, 0, IntImm::make(20) / IntImm::make(4), for_body); - auto outer_for = - For::make(i, 0, IntImm::make(15) - IntImm::make(5), inner_for); - // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) - auto b = Block::make({outer_for}); - - std::vector loops = {outer_for, inner_for}; - ForPtr flattened = nullptr; - ASSERT_TRUE(LoopNest::flatten(loops, &flattened)); - ASSERT_EQ(flattened, loops.front()); - - auto result = IRSimplifier::simplify(flattened); - checkIR(result, R"IR( - # CHECK: for (int i_flat = 0; i_flat < 50; i_flat++) { - # CHECK: A[i_flat / 5, i_flat % 5] = - )IR"); - - { - SimpleIREvaluator eval1(loops[0], {a_buf}); - PaddedBuffer inp1(10, 5); - eval1(inp1); - SimpleIREvaluator eval2(flattened, {a_buf}); - PaddedBuffer inp2(10, 5); - eval2(inp2); - ExpectAllNear(inp1, inp2, 1e-5); - } -} - -TEST(LoopNest, FlattenImperfectLoopNest) { - // Input IR: - // for (int i = 0; i < 10; i++) { - // A[i, i] = 0; - // for (int j = 0; j < 15; j++) { - // A[i,j] = i * j; - // } - // } - // Do not flatten. 
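// (This test's declarations follow below.) The flattening index math that
// the preceding tests verify, checked standalone in plain C++: a perfect
// 10x5 nest collapses to one 50-iteration loop with i = i_flat / 5 and
// j = i_flat % 5. An imperfect nest (a statement between the two loops, as
// here) has no such single-loop equivalent, so flatten must refuse.
#include <cassert>
int main() {
  int a1[10][5], a2[10][5];
  for (int i = 0; i < 10; i++)
    for (int j = 0; j < 5; j++) a1[i][j] = i * j;       // original nest
  for (int i_flat = 0; i_flat < 50; i_flat++)           // flattened loop
    a2[i_flat / 5][i_flat % 5] = (i_flat / 5) * (i_flat % 5);
  for (int i = 0; i < 10; i++)
    for (int j = 0; j < 5; j++) assert(a1[i][j] == a2[i][j]);
  return 0;
}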
- - BufHandle a_buf("A", {10, 15}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - auto for_body = Block::make({Store::make(a_buf, {i, j}, i * j)}); - auto inner_for = For::make(j, 0, 15, for_body); - auto outer_for = For::make( - i, 0, 10, Block::make({Store::make(a_buf, {i, i}, 0), inner_for})); - auto par = Block::make({outer_for}); - HashProvider hasher; - auto hash_before = hasher.hash(par); - - std::vector loops = {outer_for, inner_for}; - ForPtr flattened = nullptr; - ASSERT_FALSE(LoopNest::flatten(loops, &flattened)); - ASSERT_EQ(flattened, nullptr); - auto hash_after = hasher.hash(par); - ASSERT_EQ(hash_before, hash_after); -} - -TEST(LoopNest, FlattenReductionLoopNest) { - // Input IR: - // for (int i = 0; i < 10; i++) { - // S[i] = 0; - // for (int j = 0; j < 15; j++) { - // S[i] = S[i] + A[i,j]; - // } - // } - // Do not flatten. - - BufHandle a_buf("A", {10, 15}, kInt); - BufHandle s_buf("S", {10}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - auto for_body = Block::make({Store::make( - s_buf, {i}, Load::make(s_buf, {i}) + Load::make(a_buf, {i, j}))}); - auto inner_for = For::make(j, 0, 15, for_body); - auto outer_for = - For::make(i, 0, 10, Block::make({Store::make(s_buf, {i}, 0), inner_for})); - auto par = Block::make({outer_for}); - HashProvider hasher; - auto hash_before = hasher.hash(par); - - std::vector loops = {outer_for, inner_for}; - ForPtr flattened = nullptr; - ASSERT_FALSE(LoopNest::flatten(loops, &flattened)); - ASSERT_EQ(flattened, nullptr); - auto hash_after = hasher.hash(par); - ASSERT_EQ(hash_before, hash_after); -} - -TEST(LoopNest, FlattenReductionLoopNestFromTensor) { - const int M = 3; - const int N = 7; - VarHandle m("m", kInt); - VarHandle n("n", kInt); - BufHandle b("b", {m, n}, kFloat); - Tensor c = Reduce("sum", {M}, Sum(), b, {N}); - LoopNest loop({c}); - HashProvider hasher; - auto hash_before = hasher.hash(loop.root_stmt()); - - auto loops = loop.getAllLoopNestsWritingToBuf(c.buf())[1]; - ForPtr flattened = nullptr; - ASSERT_FALSE(LoopNest::flatten(loops, &flattened)); - ASSERT_EQ(flattened, nullptr); - auto hash_after = hasher.hash(loop.root_stmt()); - ASSERT_EQ(hash_before, hash_after); -} - -TEST(LoopNest, FlattenIncorrectLoopsAsInput) { - // Input IR: - // for (int i = 0; i < 10; i++) { - // for (int j = 0; j < 5; j++) { - // A[i,j] = i * j; - // } - // } - // for (int x = 0; x < 10; x++) { - // for (int y = 0; y < 5; y++) { - // A[x,y] = A[x,y] + x + y; - // } - // } - // Flatten({For_i, For_y}) => should not succeed - - BufHandle a_buf("A", {10, 5}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - auto for_body1 = Block::make({Store::make(a_buf, {i, j}, i * j)}); - auto inner_for1 = For::make(j, 0, 5, for_body1); - auto outer_for1 = For::make(i, 0, 10, inner_for1); - auto for_body2 = Block::make( - {Store::make(a_buf, {x, y}, Load::make(a_buf, {x, y}) + x + y)}); - auto inner_for2 = For::make(y, 0, 5, for_body2); - auto outer_for2 = For::make(x, 0, 10, inner_for2); - auto par = Block::make({outer_for1, outer_for2}); - HashProvider hasher; - auto hash_before = hasher.hash(par); - - std::vector loops = {outer_for1, inner_for2}; - ForPtr flattened = nullptr; - ASSERT_FALSE(LoopNest::flatten(loops, &flattened)); - ASSERT_EQ(flattened, nullptr); - auto hash_after = hasher.hash(par); - ASSERT_EQ(hash_before, hash_after); -} - -TEST(LoopNest, DetectInlineRankMismatch) { - const int kTotalSize = 8; - - BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); - 
Tensor a = Compute( - "a", {kTotalSize}, [&](const VarHandle& i) { return a_buf.load(i); }); - Tensor reshape = Compute( - "reshape", - {kTotalSize / 2, 2}, - [&](const VarHandle& i, const VarHandle& j) { return a.load(i, j); }); - LoopNest l({reshape}, {a, reshape}); - ASSERT_FALSE(l.computeInline(l.getLoopBodyFor(a))); -} - -TEST(LoopNest, CacheReadsSimple) { - Tensor A = Compute("A", {64, 64}, [](const VarHandle& i, const VarHandle& j) { - return i * j; - }); - Tensor B = - Compute("B", {20, 10}, [&](const VarHandle& i, const VarHandle& j) { - return A.load(i + 30, j + 3); - }); - Tensor C = - Compute("C", {20, 10}, [&](const VarHandle& i, const VarHandle& j) { - return A.load(i + 10, j + 20) + A.load(i + 30, j + 40); - }); - - LoopNest l({B, C}, {A, B, C}); - StmtPtr j_loop = l.getAllLoopNestsWritingToBuf(B.buf())[0][1]; - LoopNest::cacheAccesses(A.buf(), "A_local", j_loop); - - l.prepareForCodegen(); - StmtPtr result = - LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt())); - SimpleIREvaluator cg(result, {B, C}); - result = cg.stmt(); - - // just this once: verify the whole thing. - checkIR(result, R"IR( -#CHECK: Allocate(A); // dtype=int, dims=[64, 64] -#CHECK: Allocate(A_local); // dtype=int, dims=[1, 10] -#CHECK: for (int i -#CHECK: for (int j -#CHECK: A[ -#CHECK: } -#CHECK: } -#CHECK: for (int i_1 -#CHECK: for (int j_1 -#CHECK: A_local[j_1] = A[ -#CHECK: } -#CHECK: for (int j_2 -#CHECK: B[j_2 + 10 * i_1] = A_local[j_2]; -#CHECK: } -#CHECK: } -#CHECK: for (int i_2 -#CHECK: for (int j_3 -#CHECK: C[ -#CHECK: } -#CHECK: } -#CHECK: Free(A_local); -#CHECK: Free(A); - )IR"); - - std::vector b_data(200, 0); - std::vector c_data(200, 0); - cg.call({b_data, c_data}); - - std::vector b_ref(200, 0); - std::vector c_ref(200, 0); - - for (int i = 0; i < 20; ++i) { - for (int j = 0; j < 10; ++j) { - b_ref[i * 10 + j] = (i + 30) * (j + 3); - c_ref[i * 10 + j] = (i + 10) * (j + 20) + (i + 30) * (j + 40); - } - } - - assertAllEqual(b_data, b_ref); - assertAllEqual(c_data, c_ref); -} - -TEST(LoopNest, CacheReadsOuter) { - Tensor A = Compute("A", {64, 64}, [](const VarHandle& i, const VarHandle& j) { - return i * j; - }); - Tensor B = - Compute("B", {20, 10}, [&](const VarHandle& i, const VarHandle& j) { - return A.load(i + 30, j + 40) + A.load(i + 31, j + 41); - }); - Tensor C = - Compute("C", {20, 10}, [&](const VarHandle& i, const VarHandle& j) { - return A.load(i + 10, j + 20) + A.load(i + 30, j + 40); - }); - - LoopNest l({B, C}, {A, B, C}); - StmtPtr i_loop = l.getAllLoopNestsWritingToBuf(B.buf())[0][0]; - LoopNest::cacheAccesses(A.buf(), "A_local", i_loop); - - l.prepareForCodegen(); - StmtPtr result = - LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt())); - SimpleIREvaluator cg(result, {B, C}); - result = cg.stmt(); - - checkIR(result, R"IR( -#CHECK: Allocate(A_local); // dtype=int, dims=[21, 11] -#CHECK: A_local[j_1 + 11 * i_1] = -#CHECK: B[j_2 + 10 * i_2] = (A_local[j_2 + 11 * i_2]) + (A_local[(j_2 + 11 * i_2) + 12]); - )IR"); - - std::vector b_data(200, 0); - std::vector c_data(200, 0); - cg.call({b_data, c_data}); - - std::vector b_ref(200, 0); - std::vector c_ref(200, 0); - - for (int i = 0; i < 20; ++i) { - for (int j = 0; j < 10; ++j) { - b_ref[i * 10 + j] = (i + 30) * (j + 40) + (i + 31) * (j + 41); - c_ref[i * 10 + j] = (i + 10) * (j + 20) + (i + 30) * (j + 40); - } - } - - assertAllEqual(b_data, b_ref); - assertAllEqual(c_data, c_ref); -} - -TEST(LoopNest, CacheReadsInternal) { - Tensor A = Compute("A", {64, 64}, [](const VarHandle& i, const VarHandle& j) { - 
return i * j; - }); - Tensor B = - Compute("B", {20, 10}, [&](const VarHandle& i, const VarHandle& j) { - return A.load(i + 30, j + 40) + A.load(i + 31, j + 41); - }); - Tensor C = - Compute("C", {20, 10}, [&](const VarHandle& i, const VarHandle& j) { - return A.load(i + 10, j + 20) + A.load(i + 30, j + 40); - }); - - LoopNest l({B, C}, {A, B, C}); - StmtPtr j_loop = l.getAllLoopNestsWritingToBuf(B.buf())[0][1]; - LoopNest::cacheAccesses(A.buf(), "A_local", j_loop); - l.prepareForCodegen(); - StmtPtr result = - LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt())); - SimpleIREvaluator cg(result, {B, C}); - result = cg.stmt(); - - checkIR(result, R"IR( -#CHECK: Allocate(A_local); // dtype=int, dims=[2, 11] -#CHECK: A_local[k + 11 * j_1] = -#CHECK: B[j_2 + 10 * i_1] = (A_local[j_2 + 12]) + (A_local[j_2]); - )IR"); - - std::vector b_data(200, 0); - std::vector c_data(200, 0); - cg.call({b_data, c_data}); - - std::vector b_ref(200, 0); - std::vector c_ref(200, 0); - - for (int i = 0; i < 20; ++i) { - for (int j = 0; j < 10; ++j) { - b_ref[i * 10 + j] = (i + 30) * (j + 40) + (i + 31) * (j + 41); - c_ref[i * 10 + j] = (i + 10) * (j + 20) + (i + 30) * (j + 40); - } - } - - assertAllEqual(b_data, b_ref); - assertAllEqual(c_data, c_ref); -} - -TEST(LoopNest, CacheReadsInner) { - Tensor A = Compute("A", {64, 64}, [](const VarHandle& i, const VarHandle& j) { - return i * j; - }); - // note im changing the offset of the first arg of the first call to A. - Tensor B = - Compute("B", {20, 10}, [&](const VarHandle& i, const VarHandle& j) { - return A.load(i + 34, j + 40) + A.load(i + 30, j + 41); - }); - Tensor C = - Compute("C", {20, 10}, [&](const VarHandle& i, const VarHandle& j) { - return A.load(i + 10, j + 20) + A.load(i + 30, j + 40); - }); - - LoopNest l({B, C}, {A, B, C}); - StmtPtr body = l.getLoopBodyFor(B); - LoopNest::cacheAccesses(A.buf(), "A_local", body); - l.prepareForCodegen(); - StmtPtr result = - LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt())); - SimpleIREvaluator cg(result, {B, C}); - result = cg.stmt(); - - checkIR(result, R"IR( -#CHECK: Allocate(A_local); // dtype=int, dims=[5, 2] -#CHECK: A_local[l + 2 * k] = -#CHECK: B[j_1 + 10 * i_1] = (A_local[1]) + (A_local[8]); - )IR"); - - std::vector b_data(200, 0); - std::vector c_data(200, 0); - cg.call({b_data, c_data}); - - std::vector b_ref(200, 0); - std::vector c_ref(200, 0); - - for (int i = 0; i < 20; ++i) { - for (int j = 0; j < 10; ++j) { - b_ref[i * 10 + j] = (i + 34) * (j + 40) + (i + 30) * (j + 41); - c_ref[i * 10 + j] = (i + 10) * (j + 20) + (i + 30) * (j + 40); - } - } - - assertAllEqual(b_data, b_ref); - assertAllEqual(c_data, c_ref); -} - -TEST(LoopNest, CacheWritesSimple) { - Tensor A = Compute("A", {64, 64}, [](const VarHandle& i, const VarHandle& j) { - return i * j; - }); - Tensor B = - Compute("B", {20, 10}, [&](const VarHandle& i, const VarHandle& j) { - return A.load(i + 30, j + 40) + A.load(i + 31, j + 41); - }); - Tensor C = - Compute("C", {20, 10}, [&](const VarHandle& i, const VarHandle& j) { - return A.load(i + 10, j + 20) + A.load(i + 30, j + 40); - }); - - LoopNest l({B, C}, {A, B, C}); - StmtPtr a_loop = l.getAllLoopNestsWritingToBuf(A.buf())[0][1]; - LoopNest::cacheAccesses(A.buf(), "A_local", a_loop); - - l.prepareForCodegen(); - StmtPtr result = - LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt())); - SimpleIREvaluator cg(result, {B, C}); - result = cg.stmt(); - - checkIR(result, R"IR( -#CHECK: Allocate(A_local); // dtype=int, dims=[1, 64] -#CHECK: for (int j = 0; j < 64 
-#CHECK: A_local[j] = i * j; -#CHECK: for (int j_1 = 0; j_1 < 64 -#CHECK: A[j_1 + 64 * i] = A_local[ -#CHECK: Free(A_local); -#CHECK-NOT: A_local - )IR"); - - std::vector b_data(200, 0); - std::vector c_data(200, 0); - cg.call({b_data, c_data}); - - std::vector b_ref(200, 0); - std::vector c_ref(200, 0); - - for (int i = 0; i < 20; ++i) { - for (int j = 0; j < 10; ++j) { - b_ref[i * 10 + j] = (i + 30) * (j + 40) + (i + 31) * (j + 41); - c_ref[i * 10 + j] = (i + 10) * (j + 20) + (i + 30) * (j + 40); - } - } - - assertAllEqual(b_data, b_ref); - assertAllEqual(c_data, c_ref); -} - -TEST(LoopNest, DeadStoreElimination) { - VarHandle y("y", kInt); - VarHandle x("x_tail", kInt); - BufHandle f("f", {26, 5}, kInt); - BufHandle g("g", {26, 5}, kInt); - ExprHandle x_outer_end = 5; - ExprHandle x_2 = x + x_outer_end * 4; - ForPtr stmt1 = For::make( - x, - 0, - 5, - For::make( - y, - 0, - 5, - Block::make({ - Store::make(f, {x_2, y}, (x_2 + y)), - Store::make(g, {x_2, y}, (x_2 * y)), - }))); - StmtPtr stmt = Block::make({stmt1}); - - // Will eliminate if not used by an output. - LoopNest loop(Stmt::clone(stmt), {f.node()}); - loop.eliminateDeadStores(); - - checkIR(loop.root_stmt(), R"IR( -#CHECK: f[x_tail + 5 * 4, y] -#CHECK-NOT: g[x_tail + 5 * 4, y] - )IR"); - - // But won't eliminate if used by different outputs. - LoopNest loop2(stmt, {f.node(), g.node()}); - loop2.eliminateDeadStores(); - - checkIR(loop2.root_stmt(), R"IR( -#CHECK: f[x_tail + 5 * 4, y] -#CHECK: g[x_tail + 5 * 4, y] - )IR"); -} - -TEST(LoopNest, DeadStoreEliminationWithIntermediates) { - VarHandle x("x", kInt); - VarHandle y("y", kInt); - VarHandle z("z", kInt); - BufHandle f("f", {26 * 5}, kInt); - BufHandle g("g", {26 * 5}, kInt); - BufHandle h("h", {26, 5}, kInt); - ExprHandle x_outer_end = 5; - ExprHandle x_2 = x + x_outer_end * 4; - ForPtr stmt1 = For::make(x, 0, 26 * 5, Store::make(f, {x}, x)); - ForPtr stmt2 = For::make(z, 0, 26 * 5, Store::make(g, {z}, z + 1)); - ForPtr stmt3 = For::make( - x, - 0, - 5, - For::make( - y, - 0, - 5, - Block::make({ - Store::make(h, {x, y}, Load::make(f, {x * y})), - }))); - StmtPtr stmt = Block::make({stmt1, stmt2, stmt3}); - - // Will eliminate the write to g, but not f since it used by the producer of - // h. - LoopNest loop(Stmt::clone(stmt), {h.node()}); - loop.eliminateDeadStores(); - - checkIR(loop.root_stmt(), R"IR( - #CHECK: f[x] = x; - #CHECK-NOT: g[z] = - #CHECK: h[x, y] = f[x * y]; - )IR"); - - // Sanity check won't eliminate if g is an output. 
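// (The sanity check follows below.) The liveness rule these two cases
// exercise, sketched standalone in plain C++ over a made-up reads-from
// relation: a buffer's stores survive iff the buffer reaches an output.
#include <cassert>
#include <map>
#include <set>
#include <string>
int main() {
  // Producer -> consumers: h reads f; nothing reads g or h.
  std::map<std::string, std::set<std::string>> readers = {
      {"f", {"h"}}, {"g", {}}, {"h", {}}};
  std::set<std::string> live = {"h"};  // outputs start live
  bool changed = true;
  while (changed) {                    // propagate liveness backwards
    changed = false;
    for (auto& [buf, rs] : readers)
      for (auto& r : rs)
        if (live.count(r) && !live.count(buf)) {
          live.insert(buf);
          changed = true;
        }
  }
  assert(live.count("f") == 1);        // f feeds output h: stores kept
  assert(live.count("g") == 0);        // g feeds nothing: stores eliminated
  return 0;
}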
- LoopNest loop2(stmt, {h.node(), g.node()}); - loop2.eliminateDeadStores(); - - checkIR(loop2.root_stmt(), R"IR( - #CHECK: f[x] = x; - #CHECK: g[z] = z + 1; - #CHECK: h[x, y] = f[x * y]; - )IR"); -} - -TEST(LoopNest, CompoundTensorSimple) { - BufHandle a_buf("A", {10, 5}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - auto for_body1 = Block::make({Store::make(a_buf, {i, j}, i * j)}); - auto inner_for1 = For::make(j, 0, 5, for_body1); - auto outer_for1 = For::make(i, 0, 10, inner_for1); - auto for_body2 = Block::make( - {Store::make(a_buf, {x, y}, Load::make(a_buf, {x, y}) + x + y)}); - auto inner_for2 = For::make(y, 0, 5, for_body2); - auto outer_for2 = For::make(x, 0, 10, inner_for2); - BlockPtr body = Block::make({outer_for1, outer_for2}); - - Tensor A = Tensor(a_buf.node(), body); - - LoopNest l({A}); - l.prepareForCodegen(); - - std::vector a_data(50, 0); - - StmtPtr s = IRSimplifier::simplify(l.root_stmt()); - SimpleIREvaluator cg(s, {A}); - - std::vector a_ref(50, 0); - - for (int i = 0; i < 10; ++i) { - for (int j = 0; j < 5; ++j) { - a_ref[i * 5 + j] = (i * j) + i + j; - } - } - cg.call({a_data}); - - assertAllEqual(a_data, a_ref); -} - -TEST(LoopNest, InlineConstantIndex) { - const int N = 10; - BufHandle x_buf("a", {1, N, 1}, kFloat); - Tensor y = Compute( - "f", - {1, N, 1}, - [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& o) { - return x_buf.load(m, n, o); - }); - Tensor z = Compute( - "f", - {1, N, 1}, - [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& o) { - return y.load(m, n, o); - }); - - LoopNest l({z}, {y, z}); - l.simplify(); - ASSERT_TRUE(l.computeInline(y.buf())); -} - -TEST(LoopNest, CompoundTensorUsed) { - BufHandle a_buf("A", {10, 5}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - auto for_body1 = Block::make({Store::make(a_buf, {i, j}, i * j)}); - auto inner_for1 = For::make(j, 0, 5, for_body1); - auto outer_for1 = For::make(i, 0, 10, inner_for1); - auto for_body2 = Block::make( - {Store::make(a_buf, {x, y}, Load::make(a_buf, {x, y}) + x + y)}); - auto inner_for2 = For::make(y, 0, 5, for_body2); - auto outer_for2 = For::make(x, 0, 10, inner_for2); - BlockPtr body = Block::make({outer_for1, outer_for2}); - - Tensor A = Tensor(a_buf.node(), body); - Tensor B = Compute("B", {10, 3}, [&](const VarHandle& i, const VarHandle& j) { - return A.load(i, j + 1) + A.load(i, j + 2); - }); - - LoopNest l({B}, {A, B}); - ASSERT_FALSE(l.computeInline(A.buf())); - l.prepareForCodegen(); - - std::vector a_data(50, 0); - std::vector b_data(50, 0); - - StmtPtr s = IRSimplifier::simplify(l.root_stmt()); - SimpleIREvaluator cg(s, {B}); - - std::vector b_ref(50, 0); - - auto AT = [](int i, int j) { return i * j + i + j; }; - for (int i = 0; i < 10; ++i) { - for (int j = 0; j < 3; ++j) { - b_ref[i * 3 + j] = AT(i, j + 1) + AT(i, j + 2); - } - } - cg.call({b_data}); - - assertAllEqual(b_data, b_ref); -} - -TEST(LoopNest, InlineFromLoad) { - constexpr int N = 1024; - BufHandle a("A", {N}, kInt); - BufHandle b("B", {N}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - auto store_a = For::make(i, 0, N, Store::make(a, {i}, i)); - auto store_b = For::make(j, 0, N, Store::make(b, {j}, Load::make(a, {j}))); - LoopNest l(Block::make({store_a, store_b}), {b.node()}); - - l.computeInline(a.node()); - - // Check that A[j] is replaced with j after inlining - std::ostringstream oss; - oss << *l.root_stmt(); - 
torch::jit::testing::FileCheck().run( - R"IR( -# CHECK: for (int j -# CHECK-NOT: B[j] = A[j] -# CHECK-NEXT: B[j] = j -)IR", - oss.str()); -} - -TEST(LoopNest, OptimizeConditionalsSimple) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // A[i] = IfThenElse(i<5 ? 1 : 0, B[i], C[i-5]) - // } - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle a_buf("A", {20}, kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle b_buf("B", {5}, kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle c_buf("C", {15}, kInt); - VarHandle i("i", kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto store = Store::make( - a_buf, - {i}, - IfThenElse::make( - CompareSelect::make(i, 5, kLT), - Load::make(b_buf, {i}), - Load::make(c_buf, {i - 5}))); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto forI = For::make(i, 0, 20, store); - auto par = Block::make({forI}); - - LoopNest nest(par, {a_buf.node()}); - nest.optimizeConditionals(); - - std::ostringstream oss; - oss << *nest.root_stmt(); - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i = 0; i < 5 -# CHECK-NEXT: A[i] = B[i] -# CHECK: for (int i = 0; i < 15 -# CHECK-NEXT: A[i + 5] = C[i] - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(LoopNest, OptimizeConditionalsNestedConditions) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // A[i] = IfThenElse(i<10, IfThenElse(i<5, B[i], C[i-5]), D[i-10]) - // } - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle a_buf("A", {20}, kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle b_buf("B", {5}, kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle c_buf("C", {5}, kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle d_buf("D", {10}, kInt); - VarHandle i("i", kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto store = Store::make( - a_buf, - {i}, - IfThenElse::make( - CompareSelect::make(i, 10, kLT), - IfThenElse::make( - CompareSelect::make(i, 5, kLT), - Load::make(b_buf, {i}), - Load::make(c_buf, {i - 5})), - Load::make(d_buf, {i - 10}))); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto forI = For::make(i, 0, 20, store); - auto par = Block::make({forI}); - - LoopNest nest(par, {a_buf.node()}); - nest.optimizeConditionals(); - - std::ostringstream oss; - oss << *nest.root_stmt(); - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i = 0; i < 5 -# CHECK-NEXT: A[i] = B[i] -# CHECK: for (int i = 0; i < 5 -# CHECK-NEXT: A[i + 5] = C[i] -# CHECK: for (int i = 0; i < 10 -# CHECK-NEXT: A[i + 10] = D[i] - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(LoopNest, OptimizeConditionalsMultipleStores) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // A[i] = IfThenElse(i<5 ? 1 : 0, B[i], C[i-5]) - // } - // for (int j = 0; j < 100; j++) { - // B[j] = IfThenElse(j<30 ? 
1 : 0, C[j], D[j]) - // } - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle a_buf("A", {20}, kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle b_buf("B", {5}, kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle c_buf("C", {100}, kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle d_buf("D", {100}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto storeA = Store::make( - a_buf, - {i}, - IfThenElse::make( - CompareSelect::make(i, 5, kLT), - Load::make(b_buf, {i}), - Load::make(c_buf, {i - 5}))); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto forI = For::make(i, 0, 20, storeA); - auto storeB = Store::make( - b_buf, - {j}, - IfThenElse::make( - CompareSelect::make(j, 30, kLT), - Load::make(c_buf, {j}), - Load::make(d_buf, {j}))); - auto forJ = For::make(j, 0, 100, storeB); - auto par = Block::make({forI, forJ}); - - LoopNest nest(par, {a_buf.node()}); - nest.optimizeConditionals(); - - std::ostringstream oss; - oss << *nest.root_stmt(); - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i = 0; i < 5 -# CHECK-NEXT: A[i] = B[i] -# CHECK: for (int i = 0; i < 15 -# CHECK-NEXT: A[i + 5] = C[i] -# CHECK: for (int j = 0; j < 30 -# CHECK-NEXT: B[j] = C[j] -# CHECK: for (int j = 0; j < 70 -# CHECK-NEXT: B[j + 30] = D[j + 30] - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(LoopNest, OptimizeConditionalsMultipleStoresInOneLoop) { - // Input IR: - // for (int i = 0; i < 50; i++) { - // A[i] = IfThenElse(i<5 ? 1 : 0, B[i], C[i-5]) - // B[j] = IfThenElse(j<30 ? 1 : 0, C[j], D[j]) - // } - // Only the first conditional, in the write to A, will be optimized. - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle a_buf("A", {100}, kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle b_buf("B", {100}, kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle c_buf("C", {100}, kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle d_buf("D", {100}, kInt); - VarHandle i("i", kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto storeA = Store::make( - a_buf, - {i}, - IfThenElse::make( - CompareSelect::make(i, 5, kLT), - Load::make(b_buf, {i}), - Load::make(c_buf, {i - 5}))); - auto storeB = Store::make( - b_buf, - {i}, - IfThenElse::make( - CompareSelect::make(i, 30, kLT), - Load::make(c_buf, {i}), - Load::make(d_buf, {i}))); - auto forI = For::make(i, 0, 50, Block::make({storeA, storeB})); - auto par = Block::make({forI}); - - LoopNest nest(par, {a_buf.node()}); - nest.optimizeConditionals(); - - std::ostringstream oss; - oss << *nest.root_stmt(); - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i = 0; i < 5 -# CHECK-NEXT: A[i] = B[i] -# CHECK-NEXT: B[i] = C[i] -# CHECK: for (int i = 0; i < 45 -# CHECK-NEXT: A[i + 5] = C[i] -# CHECK-NEXT: B[i + 5] = IfThenElse(i + 5<30 ? 1 : 0, C[i + 5], D[i + 5]) - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(LoopNest, OptimizeConditionalsOuterLoopVar) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // for (int j = 0; j < 100; j++) { - // A[i] = IfThenElse(i<10, IfThenElse(i<5, B[i], C[i-5]), D[i-10]) - // } - // } - // Currently, this case where the condition variable `i` is not the - // inner-most loop variable, is not optimized. 
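// (This test's body follows below.) When the optimization does fire -- as
// in OptimizeConditionalsSimple above -- it is just a range split at the
// compare value, checked standalone in plain C++:
#include <cassert>
int main() {
  int B[5], C[15], A1[20], A2[20];
  for (int i = 0; i < 5; i++) B[i] = 100 + i;
  for (int i = 0; i < 15; i++) C[i] = 200 + i;
  for (int i = 0; i < 20; i++)               // with the per-element select
    A1[i] = (i < 5) ? B[i] : C[i - 5];
  for (int i = 0; i < 5; i++) A2[i] = B[i];        // split range [0, 5)
  for (int i = 0; i < 15; i++) A2[i + 5] = C[i];   // split range [5, 20)
  for (int i = 0; i < 20; i++) assert(A1[i] == A2[i]);
  return 0;
}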
-
-  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
-  BufHandle a_buf("A", {20}, kInt);
-  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
-  BufHandle b_buf("B", {5}, kInt);
-  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
-  BufHandle c_buf("C", {5}, kInt);
-  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
-  BufHandle d_buf("D", {10}, kInt);
-  VarHandle i("i", kInt);
-  VarHandle j("j", kInt);
-  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
-  auto store = Store::make(
-      a_buf,
-      {i},
-      IfThenElse::make(
-          CompareSelect::make(i, 10, kLT),
-          IfThenElse::make(
-              CompareSelect::make(i, 5, kLT),
-              Load::make(b_buf, {i}),
-              Load::make(c_buf, {i - 5})),
-          Load::make(d_buf, {i - 10})));
-  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
-  auto forI = For::make(i, 0, 20, For::make(j, 0, 100, store));
-  auto par = Block::make({forI});
-  LoopNest nest(par, {a_buf.node()});
-
-  HashProvider hasher;
-  auto hash_before = hasher.hash(nest.root_stmt());
-  nest.optimizeConditionals();
-  auto hash_after = hasher.hash(nest.root_stmt());
-  ASSERT_EQ(hash_before, hash_after);
-}
-
-TEST(LoopNest, OptimizeConditionalsCompValuesNotOrdered) {
-  // Input IR:
-  //   for (int i = 0; i < 20; i++) {
-  //     A[i] = IfThenElse(i<5, IfThenElse(i<10, B[i], C[i-5]), D[i-10])
-  //   }
-  // No optimization should be done here because the comparison values of the
-  // conditions are not ordered.
-
-  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
-  BufHandle a_buf("A", {20}, kInt);
-  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
-  BufHandle b_buf("B", {5}, kInt);
-  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
-  BufHandle c_buf("C", {5}, kInt);
-  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
-  BufHandle d_buf("D", {10}, kInt);
-  VarHandle i("i", kInt);
-  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
-  auto store = Store::make(
-      a_buf,
-      {i},
-      IfThenElse::make(
-          CompareSelect::make(i, 5, kLT),
-          IfThenElse::make(
-              CompareSelect::make(i, 10, kLT),
-              Load::make(b_buf, {i}),
-              Load::make(c_buf, {i - 5})),
-          Load::make(d_buf, {i - 10})));
-  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
-  auto forI = For::make(i, 0, 20, store);
-  auto par = Block::make({forI});
-  LoopNest nest(par, {a_buf.node()});
-
-  HashProvider hasher;
-  auto hash_before = hasher.hash(nest.root_stmt());
-  nest.optimizeConditionals();
-  auto hash_after = hasher.hash(nest.root_stmt());
-  ASSERT_EQ(hash_before, hash_after);
-}
-
-TEST(LoopNest, OptimizeConditionalsCompValuesNotConstants) {
-  // Input IR:
-  //   for (int i = 0; i < 20; i++) {
-  //     A[i] = IfThenElse(i<N, IfThenElse(i<5, B[i], C[i-5]), D[i-10])
-  //   }
-  // No optimization should be done here because the compare value `N` is not
-  // a constant.
- - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle a_buf("A", {20}, kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle b_buf("B", {5}, kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle c_buf("C", {5}, kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle d_buf("D", {10}, kInt); - VarHandle i("i", kInt); - VarHandle N("N", kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto store = Store::make( - a_buf, - {i}, - IfThenElse::make( - CompareSelect::make(i, N, kLT), - IfThenElse::make( - CompareSelect::make(i, 5, kLT), - Load::make(b_buf, {i}), - Load::make(c_buf, {i - 5})), - Load::make(d_buf, {i - 10}))); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto forI = For::make(i, 0, 20, store); - auto par = Block::make({forI}); - LoopNest nest(par, {a_buf.node()}); - - HashProvider hasher; - auto hash_before = hasher.hash(nest.root_stmt()); - nest.optimizeConditionals(); - auto hash_after = hasher.hash(nest.root_stmt()); - ASSERT_EQ(hash_before, hash_after); -} - -TEST(LoopNest, OptimizeConditionalsInvalidCondition) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // A[i] = IfThenElse(i<10, IfThenElse(i>5, B[i], C[i-5]), D[i-10]) - // } - // No optimization should be done here because one of the conditions uses '>'. - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle a_buf("A", {20}, kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle b_buf("B", {5}, kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle c_buf("C", {5}, kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - BufHandle d_buf("D", {10}, kInt); - VarHandle i("i", kInt); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto store = Store::make( - a_buf, - {i}, - IfThenElse::make( - CompareSelect::make(i, 10, kLT), - IfThenElse::make( - CompareSelect::make(i, 5, kGT), - Load::make(b_buf, {i}), - Load::make(c_buf, {i - 5})), - Load::make(d_buf, {i - 10}))); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - auto forI = For::make(i, 0, 20, store); - auto par = Block::make({forI}); - LoopNest nest(par, {a_buf.node()}); - - HashProvider hasher; - auto hash_before = hasher.hash(nest.root_stmt()); - nest.optimizeConditionals(); - auto hash_after = hasher.hash(nest.root_stmt()); - ASSERT_EQ(hash_before, hash_after); -} - -TEST(LoopNest, OptimizeConditionalsInvalidCondition2) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // A[i] = IfThenElse(10<i, B[i], C[i-5]) - // } - // No optimization should be done here because the condition does not - // have the loop variable on its left-hand side. - - BufHandle a_buf("A", {20}, kInt); - BufHandle b_buf("B", {5}, kInt); - BufHandle c_buf("C", {15}, kInt); - VarHandle i("i", kInt); - auto store = Store::make( - a_buf, - {i}, - IfThenElse::make( - CompareSelect::make(10, i, kLT), - Load::make(b_buf, {i}), - Load::make(c_buf, {i - 5}))); - auto forI = For::make(i, 0, 20, store); - auto par = Block::make({forI}); - LoopNest nest(par, {a_buf.node()}); - - HashProvider hasher; - auto hash_before = hasher.hash(nest.root_stmt()); - nest.optimizeConditionals(); - auto hash_after = hasher.hash(nest.root_stmt()); - ASSERT_EQ(hash_before, hash_after); -} - -static std::pair<BufHandle, Tensor> colReduce(int M, int N) { - BufHandle a("a", {M, N}, kFloat); - Tensor t = Reduce( - "b", - {N}, - Sum(), - [&](const VarHandle& n, const VarHandle& m) { return a.load(m, n); }, - {M}); - return {a, Tensor(t.buf(), LoopNest::sanitizeNames(t.stmt()))}; -} - -static StmtPtr splitTailReorder(Tensor b) { - constexpr int kVectorWidth = 8; - LoopNest nest({b}); - auto loops = nest.getAllLoopNestsWritingToBuf(b.buf())[0]; - nest.splitWithTail(loops[0], kVectorWidth); - // Now the loopnests will look like: - // - // for (int i_outer = 0; ... - // for (int i_inner = 0; ... - // b[i_outer * 8 + i_inner] = float(0); - // for (int j = 0; ... - // b[i_outer * 8 + i_inner] = ReduceOp(...); - // - // for (int i_tail = 0; ... - // b[i_tail + ((100 - 0) / 8) * 8] = float(0); - // for (int j = 0; ...
- // b[i_tail + ((100 - 0) / 8) * 8] = ReduceOp(...); - // - // Since there are 4 writes to b, we will get 4 loopnests from the - // call to `getAllLoopNestsWritingToBuf` below. - // - // Write #2: "b[i_outer * 8 + i_inner] = ReduceOp(...)" - // Loopnest #2: {i_outer, i_inner, j}; - // We will have to reorder i_inner and j. - auto loopnests = nest.getAllLoopNestsWritingToBuf(b.buf()); - LoopNest::reorderAxis(loopnests[1][1], loopnests[1][2]); - nest.prepareForCodegen(); - return nest.root_stmt(); -} - -static StmtPtr splitMaskReorder(Tensor b) { - constexpr int kVectorWidth = 8; - LoopNest nest({b}); - auto loops = nest.getAllLoopNestsWritingToBuf(b.buf())[1]; - nest.splitWithMask(loops[0], kVectorWidth); - loops = nest.getAllLoopNestsWritingToBuf(b.buf())[1]; - LoopNest::reorderAxis(loops[1], loops[2]); - nest.prepareForCodegen(); - return nest.root_stmt(); -} - -static void checkColReduce(StmtPtr s, BufHandle p, Tensor t) { - int M = immediateAs<int>(p.dim(0)); - int N = immediateAs<int>(p.dim(1)); - PaddedBuffer<float> a(M, N); - PaddedBuffer<float> b(N); - PaddedBuffer<float> ref(N); - for (int i = 0; i < M; i++) { - for (int j = 0; j < N; j++) { - a(i, j) = 1.0f; - } - } - for (int i = 0; i < N; i++) { - b(i) = 0.0f; - } - for (int i = 0; i < N; i++) { - ref(i) = 76.0f; - } - SimpleIREvaluator(s, {p, t}).call({a, b}); - ExpectAllNear(b, ref, 1e-5); -} - -TEST(LoopNest, ColReduceSplitTailEvenReorder) { - constexpr int M = 76, N = 128; - auto p = colReduce(M, N); - StmtPtr s = splitTailReorder(p.second); - - std::ostringstream oss; - oss << *s; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i_outer -# CHECK-NEXT: for (int i_inner -# CHECK-NEXT: b[ -# CHECK: for (int j -# CHECK-NEXT: for (int i_inner -# CHECK-NEXT: b[ -# CHECK-NOT: for ( - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - checkColReduce(s, p.first, p.second); -} - -TEST(LoopNest, ColReduceSplitTailUnevenReorder) { - constexpr int M = 76, N = 100; - auto p = colReduce(M, N); - StmtPtr s = splitTailReorder(p.second); - - std::ostringstream oss; - oss << *s; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i_outer -# CHECK-NEXT: for (int i_inner -# CHECK-NEXT: b[ -# CHECK: for (int j -# CHECK-NEXT: for (int i_inner -# CHECK-NEXT: b[ -# CHECK: for (int i_tail -# CHECK-NEXT: b[ -# CHECK-NEXT: for (int j -# CHECK-NEXT: b[ - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - checkColReduce(s, p.first, p.second); -} - -TEST(LoopNest, ColReduceSplitMaskEvenReorder) { - constexpr int M = 76, N = 128; - auto p = colReduce(M, N); - StmtPtr s = splitMaskReorder(p.second); - checkColReduce(s, p.first, p.second); -} - -TEST(LoopNest, ColReduceSplitMaskUnevenReorder) { - constexpr int M = 76, N = 100; - auto p = colReduce(M, N); - StmtPtr s = splitMaskReorder(p.second); - checkColReduce(s, p.first, p.second); -} - -TEST(LoopNest, ReorderAxisWithMultipleConds) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // if i > 5 { - // if i < 10 { - // for (int j = 0; j < 100; j++) { - // A[i] = i * j; - // } - // } - // } - // } - BufHandle a_buf("A", {20}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - auto forJ = For::make(j, 0, 100, Store::make(a_buf, {i}, Mul::make(i, j))); - auto inner_cond = Cond::make(CompareSelect::make(i, 10, kLT), forJ, nullptr); - auto outer_cond = - Cond::make(CompareSelect::make(i, 5, kGT), inner_cond, nullptr); - auto forI = For::make(i, 0, 20, outer_cond); - StmtPtr par = Block::make({forI}); - LoopNest l(par, 
{a_buf.node()}); - LoopNest::reorderAxis(forI, forJ); - ASSERT_EQ(par, l.root_stmt()); - par = IRSimplifier::simplify(par); - - const std::string& verification_pattern = - R"IR( -# CHECK: for (int j -# CHECK-NEXT: for (int i -# CHECK-NEXT: if (i>5 -# CHECK-NEXT: if (i<10 -# CHECK-NEXT: A[i] = i * j -# CHECK-NOT: for ( - )IR"; - std::ostringstream oss; - oss << *par; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(LoopNest, VectorizeUse) { - constexpr int N = 8; - BufHandle a("a", {N}, kFloat); - Tensor b = - Compute("b", {N}, [&](const VarHandle& n) { return a.load(n) + 1.0f; }); - Tensor c = - Compute("c", {N}, [&](const VarHandle& n) { return b.load(n) + 2.0f; }); - LoopNest nest({c}, {b, c}); - auto loops = nest.getAllLoopNestsWritingToBuf(b.buf())[0]; - ASSERT_TRUE(LoopNest::vectorize(loops[0])); - loops = nest.getAllLoopNestsWritingToBuf(c.buf())[0]; - ASSERT_TRUE(LoopNest::vectorize(loops[0])); - nest.prepareForCodegen(); - // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) - StmtPtr s = nest.root_stmt(); - std::ostringstream oss; - oss << *nest.root_stmt(); - torch::jit::testing::FileCheck().run( - R"IR( -# CHECK: c[Ramp -)IR", - oss.str()); -} - -const char* int64Loop = R"IR( -# CHECK: for (int64_t i = 0ll; i < 12ll; i++) { -# CHECK: b[i] = (a[i]) + 1ll; -# CHECK: } -)IR"; - -TEST(LoopNest, Int64Direct) { - constexpr int64_t N = 12; - BufHandle a("a", {N}, kLong); - BufHandle b("b", {N}, kLong); - VarHandle n("i", kLong); - StmtPtr s = For::make( - n, LongImm::make(0l), N, b.store({n}, a.load({n}) + LongImm::make(1l))); - s = IRSimplifier::simplify(s); - std::ostringstream oss; - oss << *s; - torch::jit::testing::FileCheck().run(int64Loop, oss.str()); -} - -TEST(LoopNest, Int64Compute) { - constexpr int64_t N = 12; - BufHandle a("a", {N}, kLong); - Tensor b = Compute("b", {N}, [&](const VarHandle& n) { - return a.load(n) + LongImm::make(1l); - }); - LoopNest nest({b}); - nest.prepareForCodegen(); - nest.simplify(); - std::ostringstream oss; - oss << *nest.root_stmt(); - torch::jit::testing::FileCheck().run(int64Loop, oss.str()); -} - -TEST(LoopNest, DistributeLoopWithAllStmtsAsPivots) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // A[i] = 0; - // for (int j = 0; j < 100; j++) { - // A[i] = A[i] + i * j; - // } - // B[i] = A[i]; - // for (int k = 0; k < 50; k++) { - // B[i] = B[i] + i * k; - // } - // } - BufHandle a_buf("A", {20}, kInt); - BufHandle b_buf("B", {20}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto initA = Store::make(a_buf, {i}, 0); - auto forJ = For::make( - j, - 0, - 100, - Store::make( - a_buf, {i}, Add::make(Load::make(a_buf, {i}), Mul::make(i, j)))); - auto initB = Store::make(b_buf, {i}, Load::make(a_buf, {i})); - auto forK = For::make( - k, - 0, - 50, - Store::make( - b_buf, {i}, Add::make(Load::make(b_buf, {i}), Mul::make(i, k)))); - auto forI = For::make(i, 0, 20, Block::make({initA, forJ, initB, forK})); - auto par = Block::make({forI}); - - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NEXT: A[i] = 0 -# CHECK: for (int i -# CHECK-NEXT: for (int j -# CHECK-NEXT: A[i] = -# CHECK: for (int i -# CHECK-NEXT: B[i] = A[i] -# CHECK: for (int i -# CHECK-NEXT: for (int k -# CHECK-NEXT: B[i] = -# CHECK-NOT: for ( - )IR"; - - LoopNest nest(par, {a_buf.node(), b_buf.node()}); - auto new_loops = LoopNest::distributeLoop(forI, {initA, forJ, initB}); - - std::ostringstream oss; - oss << *par; - 
torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - // The first loop after distribution must be the same as the original For. - ASSERT_EQ(new_loops.front(), forI); -} - -TEST(LoopNest, DistributeLoopWithOneStmtAsPivot) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // A[i] = 0; - // for (int j = 0; j < 100; j++) { - // A[i] = A[i] + i * j; - // } - // B[i] = A[i]; - // for (int k = 0; k < 50; k++) { - // B[i] = B[i] + i * k; - // } - // } - BufHandle a_buf("A", {20}, kInt); - BufHandle b_buf("B", {20}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto initA = Store::make(a_buf, {i}, 0); - auto forJ = For::make( - j, - 0, - 100, - Store::make( - a_buf, {i}, Add::make(Load::make(a_buf, {i}), Mul::make(i, j)))); - auto initB = Store::make(b_buf, {i}, Load::make(a_buf, {i})); - auto forK = For::make( - k, - 0, - 50, - Store::make( - b_buf, {i}, Add::make(Load::make(b_buf, {i}), Mul::make(i, k)))); - auto forI = For::make(i, 0, 20, Block::make({initA, forJ, initB, forK})); - auto par = Block::make({forI}); - - LoopNest nest(par, {a_buf.node(), b_buf.node()}); - auto new_loops = LoopNest::distributeLoop(forI, {forJ}); - - std::ostringstream oss; - oss << *par; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NEXT: A[i] = 0 -# CHECK-NEXT: for (int j -# CHECK-NEXT: A[i] = -# CHECK: for (int i -# CHECK-NEXT: B[i] = A[i] -# CHECK-NEXT: for (int k -# CHECK-NEXT: B[i] = -# CHECK-NOT: for ( - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - // The first loop after distribution must be the same as the original For. - ASSERT_EQ(new_loops.front(), forI); -} - -TEST(LoopNest, DistributeLoopWithoutAnyPivot) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // A[i] = 0; - // for (int j = 0; j < 100; j++) { - // A[i] = A[i] + i * j; - // } - // B[i] = A[i]; - // for (int k = 0; k < 50; k++) { - // B[i] = B[i] + i * k; - // } - // } - BufHandle a_buf("A", {20}, kInt); - BufHandle b_buf("B", {20}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto initA = Store::make(a_buf, {i}, 0); - auto forJ = For::make( - j, - 0, - 100, - Store::make( - a_buf, {i}, Add::make(Load::make(a_buf, {i}), Mul::make(i, j)))); - auto initB = Store::make(b_buf, {i}, Load::make(a_buf, {i})); - auto forK = For::make( - k, - 0, - 50, - Store::make( - b_buf, {i}, Add::make(Load::make(b_buf, {i}), Mul::make(i, k)))); - auto forI = For::make(i, 0, 20, Block::make({initA, forJ, initB, forK})); - auto par = Block::make({forI}); - - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NEXT: A[i] = 0 -# CHECK: for (int i -# CHECK-NEXT: for (int j -# CHECK-NEXT: A[i] = -# CHECK: for (int i -# CHECK-NEXT: B[i] = A[i] -# CHECK: for (int i -# CHECK-NEXT: for (int k -# CHECK-NEXT: B[i] = -# CHECK-NOT: for ( - )IR"; - - LoopNest nest(par, {a_buf.node(), b_buf.node()}); - auto new_loops = LoopNest::distributeLoop(forI); - - std::ostringstream oss; - oss << *par; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - // The first loop after distribution must be the same as the original For. 
- ASSERT_EQ(new_loops.front(), forI); -} - -TEST(LoopNest, DistributeLoopOverInnerLoops) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // A[i] = 0; - // for (int j = 0; j < 100; j++) { - // A[i] = A[i] + i * j; - // } - // B[i] = A[i]; - // for (int k = 0; k < 50; k++) { - // B[i] = B[i] + i * k; - // } - // } - BufHandle a_buf("A", {20}, kInt); - BufHandle b_buf("B", {20}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto initA = Store::make(a_buf, {i}, 0); - auto forJ = For::make( - j, - 0, - 100, - Store::make( - a_buf, {i}, Add::make(Load::make(a_buf, {i}), Mul::make(i, j)))); - auto initB = Store::make(b_buf, {i}, Load::make(a_buf, {i})); - auto forK = For::make( - k, - 0, - 50, - Store::make( - b_buf, {i}, Add::make(Load::make(b_buf, {i}), Mul::make(i, k)))); - auto forI = For::make(i, 0, 20, Block::make({initA, forJ, initB, forK})); - auto par = Block::make({forI}); - - LoopNest nest(par, {a_buf.node(), b_buf.node()}); - auto new_loops = LoopNest::distributeLoopOverInnerLoops(forI); - - std::ostringstream oss; - oss << *par; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NEXT: A[i] = 0 -# CHECK-NEXT: for (int j -# CHECK-NEXT: A[i] = -# CHECK: for (int i -# CHECK-NEXT: B[i] = A[i] -# CHECK-NEXT: for (int k -# CHECK-NEXT: B[i] = -# CHECK-NOT: for ( - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - // The first loop after distribution must be the same as the original For. - ASSERT_EQ(new_loops.front(), forI); -} - -TEST(LoopNest, DistributeLoopAndParentsWithoutAnyPivot) { - // Input IR: - // for (int m = 0; m < 50; m++) { - // for (int i = 0; i < 20; i++) { - // A[m,i] = 0; - // for (int j = 0; j < 100; j++) { - // A[m,i] = A[m,i] + i * j; - // } - // B[m,i] = A[m,i]; - // for (int k = 0; k < 50; k++) { - // B[m,i] = B[m,i] + i * k; - // } - // } - // } - BufHandle a_buf("A", {100, 100}, kInt); - BufHandle b_buf("B", {100, 100}, kInt); - VarHandle m("m", kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto initA = Store::make(a_buf, {m, i}, 0); - auto forJ = For::make( - j, - 0, - 100, - Store::make( - a_buf, - {m, i}, - Add::make(Load::make(a_buf, {m, i}), Mul::make(i, j)))); - auto initB = Store::make(b_buf, {m, i}, Load::make(a_buf, {m, i})); - auto forK = For::make( - k, - 0, - 50, - Store::make( - b_buf, - {m, i}, - Add::make(Load::make(b_buf, {m, i}), Mul::make(i, k)))); - auto forI = For::make(i, 0, 20, Block::make({initA, forJ, initB, forK})); - - { - // Check the case of distributing loop and its parents over all the - // statements in the loop. - const std::string& verification_pattern = - R"IR( -# CHECK: for (int m -# CHECK-NEXT: for (int i -# CHECK-NEXT: A[m, i] = 0 -# CHECK: for (int m -# CHECK-NEXT: for (int i -# CHECK-NEXT: for (int j -# CHECK-NEXT: A[m, i] = -# CHECK: for (int m -# CHECK-NEXT: for (int i -# CHECK-NEXT: B[m, i] = A[m, i] -# CHECK: for (int m -# CHECK-NEXT: for (int i -# CHECK-NEXT: for (int k -# CHECK-NEXT: B[m, i] = -# CHECK-NOT: for ( - )IR"; - - auto newForI = to<For>(Stmt::clone(forI)); - auto forM = For::make(m, 0, 50, newForI); - auto par = Block::make({forM}); - LoopNest nest(par, {a_buf.node(), b_buf.node()}); - auto newLoops = LoopNest::distributeLoopAndParents(newForI); - - std::ostringstream oss; - oss << *par; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - // The first loop after distribution must be the same as the original For. 
- ASSERT_EQ(newLoops.front(), forM); - } - - { - // Check the case of distributing loop and its parents over all the inner - // loops. - const std::string& verification_pattern = - R"IR( -# CHECK: for (int m -# CHECK-NEXT: for (int i -# CHECK-NEXT: A[m, i] = 0 -# CHECK-NEXT: for (int j -# CHECK-NEXT: A[m, i] = -# CHECK: for (int m -# CHECK-NEXT: for (int i -# CHECK-NEXT: B[m, i] = A[m, i] -# CHECK-NEXT: for (int k -# CHECK-NEXT: B[m, i] = -# CHECK-NOT: for ( - )IR"; - - auto newForI = to<For>(Stmt::clone(forI)); - auto forM = For::make(m, 0, 50, newForI); - auto par = Block::make({forM}); - LoopNest nest(par, {a_buf.node(), b_buf.node()}); - auto newLoops = LoopNest::distributeLoopAndParentsOverInnerLoops(newForI); - - std::ostringstream oss; - oss << *par; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - // The first loop after distribution must be the same as the original For. - ASSERT_EQ(newLoops.front(), forM); - } -} - -TEST(LoopNest, fuseLoopsSimple) { - // Input IR: - // for (int j = 0; j < 100; j++) { - // A[j] = 10 * j; - // } - // for (int k = 0; k < 100; k++) { - // B[k] = 20 * k; - // } - BufHandle a_buf("A", {100}, kInt); - BufHandle b_buf("B", {100}, kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j))); - auto forK = For::make(k, 0, 100, Store::make(b_buf, {k}, Mul::make(20, k))); - auto par = Block::make({forJ, forK}); - ForPtr fused_loop; - ASSERT_TRUE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); - - std::ostringstream oss; - oss << *par; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int j -# CHECK-NEXT: A[j] = -# CHECK-NEXT: B[j] = -# CHECK-NOT: for ( - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - // The fused loop must be the same as the first loop. - ASSERT_EQ(fused_loop, forJ); -} - -TEST(LoopNest, fuseLoopsMultiple) { - // Input IR: - // for (int i = 0; i < 100; i++) { - // A[i+100] = 20 + i; - // } - // for (int j = 0; j < 100; j++) { - // A[j] = 10 * j; - // } - // for (int k = 0; k < 100; k++) { - // B[k] = 20 * k; - // } - BufHandle a_buf("A", {200}, kInt); - BufHandle b_buf("B", {100}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto forI = - For::make(i, 0, 100, Store::make(a_buf, {i + 100}, Add::make(20, i))); - auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j))); - auto forK = For::make(k, 0, 100, Store::make(b_buf, {k}, Mul::make(20, k))); - auto par = Block::make({forI, forJ, forK}); - ForPtr fused_loop; - ASSERT_TRUE(LoopNest::fuseLoops({forI, forJ, forK}, &fused_loop)); - - std::ostringstream oss; - oss << *par; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NEXT: A[i + 100] = -# CHECK-NEXT: A[i] = -# CHECK-NEXT: B[i] = -# CHECK-NOT: for ( - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - // The fused loop must be the same as the first loop. 
- ASSERT_EQ(fused_loop, forI); -} - -TEST(LoopNest, fuseLoopsNested) { - // Input IR: - // for (int m = 0; m < 20; m++) { - // A[m] = 0; - // for (int j = 0; j < 100; j++) { - // A[m] = A[m] + m * j; - // } - // } - // for (int n = 0; n < 20; n++) { - // B[n] = A[n]; - // for (int k = 0; k < 50; k++) { - // B[n] = B[n] + n * k; - // } - // } - BufHandle a_buf("A", {20, 100}, kInt); - BufHandle b_buf("B", {20, 100}, kInt); - VarHandle m("m", kInt); - VarHandle n("n", kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto initA = Store::make(a_buf, {m}, 0); - auto forJ = For::make( - j, - 0, - 100, - Store::make( - a_buf, {m}, Add::make(Load::make(a_buf, {m}), Mul::make(m, j)))); - auto initB = Store::make(b_buf, {n}, Load::make(a_buf, {n})); - auto forK = For::make( - k, - 0, - 50, - Store::make( - b_buf, {n}, Add::make(Load::make(b_buf, {n}), Mul::make(n, k)))); - auto forM = For::make(m, 0, 20, Block::make({initA, forJ})); - auto forN = For::make(n, 0, 20, Block::make({initB, forK})); - auto par = Block::make({forM, forN}); - ForPtr fused_loop; - ASSERT_TRUE(LoopNest::fuseLoops({forM, forN}, &fused_loop)); - - std::ostringstream oss; - oss << *par; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int m -# CHECK-NEXT: A[m] = 0 -# CHECK-NEXT: for (int j -# CHECK-NEXT: A[m] = -# CHECK: B[m] = A[m] -# CHECK-NEXT: for (int k -# CHECK-NEXT: B[m] = -# CHECK-NOT: for ( - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - // The fused loop must be the same as the first loop. - ASSERT_EQ(fused_loop, forM); -} - -TEST(LoopNest, fuseLoopsNested2D) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // for (int j = 0; j < 100; j++) { - // A[i,j] = i * j * 500; - // } - // } - // for (int m = 0; m < 20; m++) { - // for (int n = 0; n < 50; n++) { - // B[m,n] = m + n * 100; - // } - // } - BufHandle a_buf("A", {20, 100}, kInt); - BufHandle b_buf("B", {20, 100}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle m("m", kInt); - VarHandle n("n", kInt); - auto forI = For::make( - i, - 0, - 20, - For::make( - j, - 0, - 100, - Store::make(a_buf, {i, j}, Mul::make(Mul::make(i, j), 500)))); - auto forM = For::make( - m, - 0, - 20, - For::make( - n, - 0, - 50, - Store::make(b_buf, {m, n}, Add::make(m, Mul::make(n, 100))))); - auto par = Block::make({forI, forM}); - ForPtr fused_loop; - ASSERT_TRUE(LoopNest::fuseLoops({forI, forM}, &fused_loop)); - - std::ostringstream oss; - oss << *par; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NEXT: for (int j -# CHECK-NEXT: A[i, j] = -# CHECK: for (int n -# CHECK-NEXT: B[i, n] = -# CHECK-NOT: for ( - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - // The fused loop must be the same as the first loop. 
- ASSERT_EQ(fused_loop, forI); -} - -TEST(LoopNest, fuseLoopsNested2DInner) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // for (int j = 0; j < 100; j++) { - // A[i,j] = i * j * 500; - // } - // for (int n = 0; n < 100; n++) { - // B[i,n] = i + n * 100; - // } - // } - BufHandle a_buf("A", {20, 100}, kInt); - BufHandle b_buf("B", {20, 100}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle n("n", kInt); - auto forJ = For::make( - j, 0, 100, Store::make(a_buf, {i, j}, Mul::make(Mul::make(i, j), 500))); - auto forN = For::make( - n, 0, 100, Store::make(b_buf, {i, n}, Add::make(i, Mul::make(n, 100)))); - auto forI = For::make(i, 0, 20, Block::make({forJ, forN})); - ForPtr fused_loop; - ASSERT_TRUE(LoopNest::fuseLoops({forJ, forN}, &fused_loop)); - - std::ostringstream oss; - oss << *forI; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NEXT: for (int j -# CHECK-NEXT: A[i, j] = -# CHECK-NEXT: B[i, j] = -# CHECK-NOT: for ( - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - // The fused loop must be the same as the first loop. - ASSERT_EQ(fused_loop, forJ); -} - -TEST(LoopNest, fuseLoopsDifferentStopBounds) { - // Input IR: - // for (int j = 0; j < 100; j++) { - // A[j] = 10 * j; - // } - // for (int k = 0; k < 50; k++) { - // B[k] = 20 * k; - // } - BufHandle a_buf("A", {100}, kInt); - BufHandle b_buf("B", {100}, kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j))); - auto forK = For::make(k, 0, 50, Store::make(b_buf, {k}, Mul::make(20, k))); - // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) - auto par = Block::make({forJ, forK}); - ForPtr fused_loop; - ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); -} - -TEST(LoopNest, fuseLoopsDifferentStartBounds) { - // Input IR: - // for (int j = 0; j < 100; j++) { - // A[j] = 10 * j; - // } - // for (int k = 50; k < 100; k++) { - // B[k] = 20 * k; - // } - BufHandle a_buf("A", {100}, kInt); - BufHandle b_buf("B", {100}, kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j))); - auto forK = For::make(k, 50, 100, Store::make(b_buf, {k}, Mul::make(20, k))); - // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) - auto par = Block::make({forJ, forK}); - ForPtr fused_loop; - ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); -} - -TEST(LoopNest, fuseLoopsNotContiguous) { - // Input IR: - // for (int j = 0; j < 100; j++) { - // A[j] = 10 * j; - // } - // B[0] = 0; - // for (int k = 0; k < 100; k++) { - // B[k] = 20 * k; - // } - BufHandle a_buf("A", {100}, kInt); - BufHandle b_buf("B", {100}, kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j))); - auto initB = Store::make(b_buf, {0}, 0); - auto forK = For::make(k, 0, 100, Store::make(b_buf, {k}, Mul::make(20, k))); - // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) - auto par = Block::make({forJ, initB, forK}); - ForPtr fused_loop; - ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); -} - -TEST(LoopNest, fuseLoopsWithDifferentParents) { - // Input IR: - // for (int i = 0; i < 50; i++) { - // for (int j = 0; j < 100; j++) { - // A[i,j] = i * j; - // } - // } - // B[0] = 0; - // for (int k = 50; k < 100; k++) { - // B[k] = 20 * k; - // } - BufHandle a_buf("A", {50, 100}, kInt); - BufHandle b_buf("B", {100}, kInt); - 
VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto forJ = For::make(j, 0, 100, Store::make(a_buf, {i, j}, Mul::make(i, j))); - auto forI = For::make(i, 0, 50, forJ); - auto initB = Store::make(b_buf, {0}, 0); - auto forK = For::make(k, 50, 100, Store::make(b_buf, {j}, Mul::make(20, k))); - // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) - auto par = Block::make({forI, initB, forK}); - ForPtr fused_loop; - ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); -} - -TEST(LoopNest, fuseLoopsWithVariableBounds) { - // Input IR: - // for (int j = 0; j < N; j++) { - // A[j] = 10 * j; - // } - // for (int k = 0; k < N; k++) { - // B[k] = 20 * k; - // } - BufHandle a_buf("A", {20}, kInt); - BufHandle b_buf("B", {20}, kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - VarHandle N("N", kInt); - auto forJ = For::make(j, 0, N, Store::make(a_buf, {j}, Mul::make(10, j))); - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks,cppcoreguidelines-avoid-magic-numbers) - auto forK = For::make(k, 0, N, Store::make(b_buf, {j}, Mul::make(20, k))); - auto par = Block::make({forJ, forK}); - ForPtr fused_loop; - ASSERT_TRUE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); - - std::ostringstream oss; - oss << *par; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int j -# CHECK-NEXT: A[j] = -# CHECK-NEXT: B[j] = -# CHECK-NOT: for ( - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - // The fused loop must be the same as the first loop. - ASSERT_EQ(fused_loop, forJ); -} - -TEST(LoopNest, fuseLoopsWithExprBounds) { - // Input IR: - // for (int j = 0; j < M + N; j++) { - // A[j] = 10 * j; - // } - // for (int k = 0; k < M + N; k++) { - // B[k] = 20 * k; - // } - BufHandle a_buf("A", {20}, kInt); - BufHandle b_buf("B", {20}, kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - VarHandle M("M", kInt); - VarHandle N("N", kInt); - auto forJ = For::make(j, 0, M + N, Store::make(a_buf, {j}, Mul::make(10, j))); - auto forK = For::make(k, 0, M + N, Store::make(b_buf, {j}, Mul::make(20, k))); - auto par = Block::make({forJ, forK}); - ForPtr fused_loop; - ASSERT_TRUE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); - - std::ostringstream oss; - oss << *par; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int j -# CHECK-NEXT: A[j] = -# CHECK-NEXT: B[j] = -# CHECK-NOT: for ( - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - // The fused loop must be the same as the first loop. 
- ASSERT_EQ(fused_loop, forJ); -} - -TEST(LoopNest, fuseLoopsWithDifferentExprBounds) { - // Input IR: - // for (int j = M; j < N * 2; j++) { - // A[j] = 10 * j; - // } - // for (int k = M; k < N + N; k++) { - // B[k] = 20 * k; - // } - BufHandle a_buf("A", {20}, kInt); - BufHandle b_buf("B", {20}, kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - VarHandle M("M", kInt); - VarHandle N("N", kInt); - auto forJ = For::make(j, M, N * 2, Store::make(a_buf, {j}, Mul::make(10, j))); - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks,cppcoreguidelines-avoid-magic-numbers) - auto forK = For::make(k, M, N + N, Store::make(b_buf, {k}, Mul::make(20, k))); - auto par = Block::make({forJ, forK}); - ForPtr fused_loop; - ASSERT_TRUE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); - - std::ostringstream oss; - oss << *par; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int j -# CHECK-NEXT: A[j] = -# CHECK-NEXT: B[j] = -# CHECK-NOT: for ( - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - // The fused loop must be the same as the first loop. - ASSERT_EQ(fused_loop, forJ); -} - -TEST(LoopNest, fuseLoopsWithNonOverlappingBufferAccesses) { - // Input IR: - // for (int j = 10; j < 100; j++) { - // A[j] = 10 * j; - // } - // for (int k = 10; k < 100; k++) { - // A[k+100] = 30 * k - // } - BufHandle a_buf("A", {200}, kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto forJ = For::make(j, 10, 100, Store::make(a_buf, {j}, Mul::make(10, j))); - auto forK = - For::make(k, 10, 100, Store::make(a_buf, {k + 100}, Mul::make(30, k))); - auto par = Block::make({forJ, forK}); - - ForPtr fused_loop; - ASSERT_TRUE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); - - std::ostringstream oss; - oss << *par; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int j -# CHECK-NEXT: A[j] = -# CHECK-NEXT: A[j + 100] = -# CHECK-NOT: for ( - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - // The fused loop must be the same as the first loop. - ASSERT_EQ(fused_loop, forJ); -} - -TEST(LoopNest, fuseLoopsWithNonOverlapping2DBufferAccesses) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // for (int j = 0; j < 100; j++) { - // A[i,j] = i * j * 500; - // } - // } - // for (int m = 0; m < 20; m++) { - // for (int n = 0; n < 50; n++) { - // A[m+20,n+100] = m + n * 100; - // } - // } - BufHandle a_buf("A", {20, 100}, kInt); - BufHandle b_buf("B", {20, 50}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle m("m", kInt); - VarHandle n("n", kInt); - auto storeA1 = Store::make(a_buf, {i, j}, Mul::make(Mul::make(i, j), 500)); - auto forJ = For::make(j, 0, 100, storeA1); - auto forI = For::make(i, 0, 20, forJ); - auto storeA2 = - Store::make(a_buf, {m + 20, n + 100}, Add::make(m, Mul::make(n, 100))); - auto forN = For::make(n, 0, 50, storeA2); - auto forM = For::make(m, 0, 20, forN); - auto par = Block::make({forI, forM}); - - ForPtr fused_loop; - ASSERT_TRUE(LoopNest::fuseLoops({forI, forM}, &fused_loop)); - - std::ostringstream oss; - oss << *par; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NEXT: for (int j -# CHECK-NEXT: A[i, j] = -# CHECK: for (int n -# CHECK-NEXT: A[i + 20, n + 100] = -# CHECK-NOT: for ( - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - // The fused loop must be the same as the first loop. 
- ASSERT_EQ(fused_loop, forI); -} - -TEST(LoopNest, fuseLoopsWithReductions) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // A[i] = 0 - // for (int j = 0; j < 100; j++) { - // A[i] = A[i] + B[i,j]; - // } - // } - // for (int m = 0; m < 20; m++) { - // C[m] = A[m]; - // } - BufHandle a_buf("A", {20}, kInt); - BufHandle b_buf("B", {20, 100}, kInt); - BufHandle c_buf("C", {20}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle m("m", kInt); - auto initA = Store::make(a_buf, {i}, 0); - auto sumA = Store::make( - a_buf, {i}, Add::make(Load::make(a_buf, {i}), Load::make(b_buf, {i, j}))); - auto forJ = For::make(j, 0, 100, sumA); - auto forI = For::make(i, 0, 20, Block::make({initA, forJ})); - auto forM = - For::make(m, 0, 20, Store::make(c_buf, {m}, Load::make(a_buf, {m}))); - auto par = Block::make({forI, forM}); - ForPtr fused_loop; - ASSERT_TRUE(LoopNest::fuseLoops({forI, forM}, &fused_loop)); - - std::ostringstream oss; - oss << *par; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NEXT: A[i] = -# CHECK-NEXT: for (int j -# CHECK-NEXT: A[i] = (A[i]) + -# CHECK-NOT: for ( -# CHECK: C[i] = A[i] - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - // The fused loop must be the same as the first loop. - ASSERT_EQ(fused_loop, forI); -} - -TEST(LoopNest, fuseLoopsWith2DReductions) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // for (int j = 0; j < 50; j++) { - // A[i,j] = 0 - // for (int k = 0; k < 100; k++) { - // A[i,j] = A[i,j] + B[i,j,k]; - // } - // } - // } - // for (int m = 0; m < 20; m++) { - // for (int n = 0; n < 40; n++) { - // C[m,n] = A[m,n]; - // } - // } - BufHandle a_buf("A", {20, 50}, kInt); - BufHandle b_buf("B", {20, 50, 100}, kInt); - BufHandle c_buf("C", {20, 40}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - VarHandle m("m", kInt); - VarHandle n("n", kInt); - auto initA = Store::make(a_buf, {i, j}, 0); - auto sumA = Store::make( - a_buf, - {i, j}, - Add::make(Load::make(a_buf, {i, j}), Load::make(b_buf, {i, j, k}))); - auto forK = For::make(k, 0, 100, sumA); - auto forJ = For::make(j, 0, 50, Block::make({initA, forK})); - auto forI = For::make(i, 0, 20, forJ); - auto storeC = Store::make(c_buf, {m, n}, Load::make(a_buf, {m, n})); - auto forM = For::make(m, 0, 20, For::make(n, 0, 40, storeC)); - auto par = Block::make({forI, forM}); - - ForPtr fused_loop; - ASSERT_TRUE(LoopNest::fuseLoops({forI, forM}, &fused_loop)); - - std::ostringstream oss; - oss << *par; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NEXT: for (int j -# CHECK-NEXT: A[i, j] = -# CHECK-NEXT: for (int k -# CHECK-NEXT: A[i, j] = (A[i, j]) + -# CHECK: for (int n -# CHECK-NEXT: C[i, n] = A[i, n] -# CHECK-NOT: for ( - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - // The fused loop must be the same as the first loop. 
- ASSERT_EQ(fused_loop, forI); -} - -TEST(LoopNest, fuseLoopsWithComplexIndices) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // for (int j = 0; j < 20; j++) { - // A[i,j*20+j+2] = i + j; - // } - // } - // for (int m = 0; m < 20; m++) { - // for (int n = 0; n < 20; n++) { - // B[m,n] = A[m,n*20+n+2]; - // } - // } - BufHandle a_buf("A", {20, 400}, kInt); - BufHandle b_buf("B", {20, 400}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle m("m", kInt); - VarHandle n("n", kInt); - auto writeA = Store::make(a_buf, {i, j * 20 + j + 2}, i + j); - auto forI = For::make(i, 0, 20, For::make(j, 0, 20, writeA)); - auto storeB = - Store::make(b_buf, {m, n}, Load::make(a_buf, {m, n * 20 + n + 2})); - auto forM = For::make(m, 0, 20, For::make(n, 0, 20, storeB)); - auto par = Block::make({forI, forM}); - - ForPtr fused_loop; - ASSERT_TRUE(LoopNest::fuseLoops({forI, forM}, &fused_loop)); - - std::ostringstream oss; - oss << *par; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NEXT: for (int j -# CHECK-NEXT: A[i, (j * 20 + j) + 2] = i + j -# CHECK: for (int n -# CHECK-NEXT: B[i, n] = A[i, (n * 20 + n) + 2] -# CHECK-NOT: for ( - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - // The fused loop must be the same as the first loop. - ASSERT_EQ(fused_loop, forI); -} - -TEST(LoopNest, fuseLoopsWithMixedLoopVarsAsIndices) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // for (int j = 0; j < 20; j++) { - // A[i,i*20+j] = i + j; - // } - // } - // for (int m = 0; m < 20; m++) { - // for (int n = 0; n < 20; n++) { - // B[m,n] = A[m,m*20+n]; // Both indices of A use m - // } - // } - BufHandle a_buf("A", {20, 500}, kInt); - BufHandle b_buf("B", {20, 500}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle m("m", kInt); - VarHandle n("n", kInt); - auto writeA = Store::make(a_buf, {i, i * 20 + j}, i + j); - auto forI = For::make(i, 0, 20, For::make(j, 0, 20, writeA)); - auto storeB = Store::make(b_buf, {m, n}, Load::make(a_buf, {m, m * 20 + n})); - auto forM = For::make(m, 0, 20, For::make(n, 0, 20, storeB)); - auto par = Block::make({forI, forM}); - - ForPtr fused_loop; - ASSERT_FALSE(LoopNest::fuseLoops({forI, forM}, &fused_loop)); -} - -TEST(LoopNest, fuseLoopsWithTranspose) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // for (int j = 0; j < 20; j++) { - // A[i,j] = i + j; - // } - // } - // for (int m = 0; m < 20; m++) { - // for (int n = 0; n < 20; n++) { - // B[m,n] = A[n,m]; // Transpose - // } - // } - BufHandle a_buf("A", {20, 20}, kInt); - BufHandle b_buf("B", {20, 20}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle m("m", kInt); - VarHandle n("n", kInt); - auto writeA = Store::make(a_buf, {i, j}, i + j); - auto forI = For::make(i, 0, 20, For::make(j, 0, 20, writeA)); - auto storeB = Store::make(b_buf, {m, n}, Load::make(a_buf, {n, m})); - auto forM = For::make(m, 0, 20, For::make(n, 0, 20, storeB)); - auto par = Block::make({forI, forM}); - - ForPtr fused_loop; - ASSERT_FALSE(LoopNest::fuseLoops({forI, forM}, &fused_loop)); -} - -TEST(LoopNest, fuseLoopsThatViolateDependencies1) { - // Input IR: - // for (int j = 10; j < 100; j++) { - // A[j] = 10 * j; - // } - // for (int k = 10; k < 100; k++) { - // A[k-1] = 20 * k; - // } - BufHandle a_buf("A", {100}, kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto forJ = For::make(j, 10, 100, Store::make(a_buf, {j}, Mul::make(10, j))); - auto forK = - For::make(k, 10, 100, Store::make(a_buf, 
{k - 1}, Mul::make(20, k))); - // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) - auto par = Block::make({forJ, forK}); - ForPtr fused_loop; - ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); -} - -TEST(LoopNest, fuseLoopsThatViolateDependencies2) { - // Input IR: - // for (int j = 10; j < 100; j++) { - // A[j] = 10 * j; - // } - // for (int k = 10; k < 100; k++) { - // A[k+50] = 20 * k; - // } - BufHandle a_buf("A", {150}, kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto forJ = For::make(j, 10, 100, Store::make(a_buf, {j}, Mul::make(10, j))); - auto forK = - For::make(k, 10, 100, Store::make(a_buf, {k + 50}, Mul::make(20, k))); - // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) - auto par = Block::make({forJ, forK}); - ForPtr fused_loop; - ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); -} - -TEST(LoopNest, fuseLoopsThatViolateDependencies3) { - // Input IR: - // for (int m = 0; m < 20; m++) { - // A[m] = 0; - // for (int j = 0; j < 100; j++) { - // A[m] = A[m] + m * j; - // } - // } - // for (int n = 0; n < 20; n++) { - // B[n] = A[n+1]; - // for (int k = 0; k < 50; k++) { - // B[n] = B[n] + n * k; - // } - // } - BufHandle a_buf("A", {25, 100}, kInt); - BufHandle b_buf("B", {20, 50}, kInt); - VarHandle m("m", kInt); - VarHandle n("n", kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto initA = Store::make(a_buf, {m}, 0); - auto forJ = For::make( - j, - 0, - 100, - Store::make( - a_buf, {m}, Add::make(Load::make(a_buf, {m}), Mul::make(m, j)))); - auto initB = Store::make(b_buf, {n}, Load::make(a_buf, {n + 1})); - auto forK = For::make( - k, - 0, - 50, - Store::make( - b_buf, {n}, Add::make(Load::make(b_buf, {n}), Mul::make(n, k)))); - auto forM = For::make(m, 0, 20, Block::make({initA, forJ})); - auto forN = For::make(n, 0, 20, Block::make({initB, forK})); - // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) - auto par = Block::make({forM, forN}); - ForPtr fused_loop; - ASSERT_FALSE(LoopNest::fuseLoops({forM, forN}, &fused_loop)); -} - -TEST(LoopNest, fuseLoopsThatViolateDependencies4) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // for (int j = 0; j < 100; j++) { - // A[i,j] = i * j * 500; - // } - // } - // for (int m = 0; m < 20; m++) { - // for (int n = 0; n < 50; n++) { - // A[m+1,n] = m + n * 100; - // } - // } - BufHandle a_buf("A", {30, 100}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle m("m", kInt); - VarHandle n("n", kInt); - auto forI = For::make( - i, - 0, - 20, - For::make( - j, - 0, - 100, - Store::make(a_buf, {i, j}, Mul::make(Mul::make(i, j), 500)))); - auto forM = For::make( - m, - 0, - 20, - For::make( - n, - 0, - 50, - Store::make(a_buf, {m + 1, n}, Add::make(m, Mul::make(n, 100))))); - // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) - auto par = Block::make({forI, forM}); - ForPtr fused_loop; - ASSERT_FALSE(LoopNest::fuseLoops({forI, forM}, &fused_loop)); -} - -TEST(LoopNest, fuseLoopsThatViolateDependencies5) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // for (int j = 0; j < 100; j++) { - // A[i,j] = i * j * 500; - // } - // for (int n = 0; n < 100; n++) { - // A[i,n+1] = i + n * 100; - // } - // } - BufHandle a_buf("A", {20, 200}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle n("n", kInt); - auto forJ = For::make( - j, 0, 100, Store::make(a_buf, {i, j}, Mul::make(Mul::make(i, j), 500))); - auto forN = For::make( - n, - 0, - 100, - Store::make(a_buf, {i, n + 1}, Add::make(i, Mul::make(n, 100)))); - // 
NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores,cppcoreguidelines-avoid-magic-numbers) - auto forI = For::make(i, 0, 20, Block::make({forJ, forN})); - ForPtr fused_loop; - ASSERT_FALSE(LoopNest::fuseLoops({forJ, forN}, &fused_loop)); -} - -TEST(LoopNest, fuseLoopsThatViolateDependencies6) { - // Input IR: - // for (int j = 0; j < 100; j++) { - // A[j] = 10 * j; - // } - // for (int k = 0; k < 100; k++) { - // B[k] = 20 * A[99-k]; - // } - BufHandle a_buf("A", {100}, kInt); - BufHandle b_buf("B", {100}, kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j))); - auto forK = For::make( - k, - 0, - 100, - Store::make( - b_buf, {k}, Mul::make(20, Load::make(a_buf, {ExprHandle(99) - k})))); - // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) - auto par = Block::make({forJ, forK}); - ForPtr fused_loop; - ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); -} - -TEST(LoopNest, fuseLoopsThatViolateDependencies7) { - // Input IR: - // for (int k = 0; k < 100; k++) { - // B[k] = 20 * A[99-k]; - // } - // for (int j = 0; j < 100; j++) { - // A[j] = 10 * j; - // } - BufHandle a_buf("A", {100}, kInt); - BufHandle b_buf("B", {100}, kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto forK = For::make( - k, - 0, - 100, - Store::make( - b_buf, {k}, Mul::make(20, Load::make(a_buf, {ExprHandle(99) - k})))); - auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j))); - // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) - auto par = Block::make({forK, forJ}); - ForPtr fused_loop; - ASSERT_FALSE(LoopNest::fuseLoops({forK, forJ}, &fused_loop)); -} - -TEST(LoopNest, areLoopsPerfectlyNested) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // for (int j = 0; j < 30; j++) { - // for (int k = 0; k < 40; k++) { - // A[i,j,k] = i * j * k; - // } - // } - // } - BufHandle a_buf("A", {20, 30, 40}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto store = Store::make(a_buf, {i, j, k}, Mul::make(Mul::make(i, j), k)); - auto forK = For::make(k, 0, 40, store); - auto forJ = For::make(j, 0, 30, forK); - auto forI = For::make(i, 0, 20, forJ); - // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) - auto par = Block::make({forI}); - ASSERT_TRUE(LoopNest::areLoopsPerfectlyNested({forI, forJ, forK})); - - // Specifying the loops in any other order fails. - ASSERT_FALSE(LoopNest::areLoopsPerfectlyNested({forJ, forI, forK})); - ASSERT_FALSE(LoopNest::areLoopsPerfectlyNested({forI, forK, forJ})); - ASSERT_FALSE(LoopNest::areLoopsPerfectlyNested({forK, forJ, forI})); - - // Adding a statement to forK body should be OK. - auto init = Store::make(a_buf, {i, j}, 0); - forK->body()->insert_stmt_before(init, store); - ASSERT_TRUE(LoopNest::areLoopsPerfectlyNested({forI, forJ, forK})); - - // Adding a statement in forJ body should fail this test. - forK->body()->remove_stmt(init); - forJ->body()->insert_stmt_before(init, forK); - ASSERT_FALSE(LoopNest::areLoopsPerfectlyNested({forI, forJ, forK})); - - // Similarly, adding a statement in forI body should fail this test. 
- forJ->body()->remove_stmt(init); - forI->body()->insert_stmt_before(init, forJ); - ASSERT_FALSE(LoopNest::areLoopsPerfectlyNested({forI, forJ, forK})); -} - -TEST(LoopNest, reorderNestedLoops2D) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // for (int j = 0; j < 30; j++) { - // A[i,j] = i * j; - // } - // } - BufHandle a_buf("A", {20, 30, 40}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - auto store = Store::make(a_buf, {i, j}, Mul::make(i, j)); - auto forJ = For::make(j, 0, 30, store); - auto forI = For::make(i, 0, 20, forJ); - auto par = Block::make({forI}); - - auto reordered = LoopNest::reorder({forI, forJ}, {1, 0}); - - ASSERT_EQ(reordered[0], forJ); - ASSERT_EQ(reordered[1], forI); - ASSERT_TRUE(LoopNest::areLoopsPerfectlyNested({forJ, forI})); - ASSERT_EQ(forJ->get_parent(), par); - ASSERT_EQ(store->get_parent(), forI->body()); -} - -TEST(LoopNest, reorderNestedLoops3D) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // for (int j = 0; j < 30; j++) { - // for (int k = 0; k < 40; k++) { - // A[i,j,k] = i * j * k; - // } - // } - // } - BufHandle a_buf("A", {20, 30, 40}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto store = Store::make(a_buf, {i, j, k}, Mul::make(Mul::make(i, j), k)); - auto forK = For::make(k, 0, 40, store); - auto forJ = For::make(j, 0, 30, forK); - auto forI = For::make(i, 0, 20, forJ); - auto par = Block::make({forI}); - - auto reordered = LoopNest::reorder({forI, forJ, forK}, {2, 0, 1}); - - ASSERT_EQ(reordered[0], forK); - ASSERT_EQ(reordered[1], forI); - ASSERT_EQ(reordered[2], forJ); - ASSERT_TRUE(LoopNest::areLoopsPerfectlyNested({forK, forI, forJ})); - ASSERT_EQ(forK->get_parent(), par); - ASSERT_EQ(store->get_parent(), forJ->body()); -} - -TEST(LoopNest, reorderNestedLoops4D) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // for (int j = 0; j < 30; j++) { - // for (int k = 0; k < 40; k++) { - // for (int l = 0; l < 50; l++) { - // A[i,j,k,l] = i * j * k * l * 500; - // } - // } - // } - // } - BufHandle a_buf("A", {20, 30, 40, 50}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - VarHandle l("l", kInt); - auto store = Store::make( - a_buf, - {i, j, k, l}, - Mul::make(Mul::make(Mul::make(Mul::make(i, j), k), l), 500)); - auto forL = For::make(l, 0, 50, store); - auto forK = For::make(k, 0, 40, forL); - auto forJ = For::make(j, 0, 30, forK); - auto forI = For::make(i, 0, 20, forJ); - auto par = Block::make({forI}); - - auto reordered = LoopNest::reorder({forI, forJ, forK, forL}, {2, 0, 3, 1}); - - ASSERT_EQ(reordered[0], forK); - ASSERT_EQ(reordered[1], forI); - ASSERT_EQ(reordered[2], forL); - ASSERT_EQ(reordered[3], forJ); - ASSERT_TRUE(LoopNest::areLoopsPerfectlyNested({forK, forI, forL, forJ})); - ASSERT_EQ(forK->get_parent(), par); - ASSERT_EQ(store->get_parent(), forJ->body()); -} - -TEST(LoopNest, reorderTrivialPermutation) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // for (int j = 0; j < 30; j++) { - // for (int k = 0; k < 40; k++) { - // A[i,j,k] = i * j * k; - // } - // } - // } - BufHandle a_buf("A", {20, 30, 40}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto store = Store::make(a_buf, {i, j, k}, Mul::make(Mul::make(i, j), k)); - auto forK = For::make(k, 0, 40, store); - auto forJ = For::make(j, 0, 30, forK); - auto forI = For::make(i, 0, 20, forJ); - auto par = Block::make({forI}); - - auto reordered = LoopNest::reorder({forI, forJ, forK}, {0, 1, 2}); - - 
ASSERT_EQ(reordered[0], forI); - ASSERT_EQ(reordered[1], forJ); - ASSERT_EQ(reordered[2], forK); - ASSERT_TRUE(LoopNest::areLoopsPerfectlyNested({forI, forJ, forK})); - ASSERT_EQ(forI->get_parent(), par); - ASSERT_EQ(store->get_parent(), forK->body()); -} - -TEST(LoopNest, reorderInvalidPermutations) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // for (int j = 0; j < 30; j++) { - // for (int k = 0; k < 40; k++) { - // A[i,j,k] = i * j * k; - // } - // } - // } - BufHandle a_buf("A", {20, 30, 40}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto store = Store::make(a_buf, {i, j, k}, Mul::make(Mul::make(i, j), k)); - auto forK = For::make(k, 0, 40, store); - auto forJ = For::make(j, 0, 30, forK); - auto forI = For::make(i, 0, 20, forJ); - // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) - auto par = Block::make({forI}); - - ASSERT_THROWS_WITH( - LoopNest::reorder({forI, forJ, forK}, {0, 1, 2, 3}), - "invalid permutation size"); - ASSERT_THROWS_WITH( - LoopNest::reorder({forI, forJ, forK}, {1, 2}), - "invalid permutation size"); - ASSERT_THROWS_WITH( - LoopNest::reorder({forI, forJ, forK}, {2, 1, 3}), - "invalid permutation for reorder"); - ASSERT_THROWS_WITH( - LoopNest::reorder({forI, forJ, forK}, {1, 1, 0}), - "invalid permutation for reorder"); - ASSERT_THROWS_WITH( - LoopNest::reorder({forI, forJ, forK}, {0, 0, 0}), - "invalid permutation for reorder"); -} - -TEST(LoopNest, reorderInvalidLoopNest) { - // Input IR: - // for (int i = 0; i < 20; i++) { - // for (int j = 0; j < 30; j++) { - // A[i] = 0 - // for (int k = 0; k < 40; k++) { - // A[i,j,k] = i * j * k; - // } - // } - // } - BufHandle a_buf("A", {20, 30, 40}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto store = Store::make(a_buf, {i, j, k}, Mul::make(Mul::make(i, j), k)); - auto forK = For::make(k, 0, 40, store); - auto forJ = For::make(j, 0, 30, forK); - auto forI = For::make(i, 0, 20, forJ); - // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) - auto par = Block::make({forI}); - - // Specifying the loops in an incorrect order fails. - ASSERT_THROWS_WITH( - LoopNest::reorder({forK, forI, forJ}, {1, 0, 2}), - "reorder is only allowed on perfectly nested loops"); - - // Adding a statement to the forJ loop fails. - auto init = Store::make(a_buf, {i}, 0); - forJ->body()->insert_stmt_before(init, forK); - ASSERT_THROWS_WITH( - LoopNest::reorder({forI, forJ, forK}, {1, 0, 2}), - "reorder is only allowed on perfectly nested loops"); - - // Moving that statement to the forI loop also fails. 
- forJ->body()->remove_stmt(init); - forI->body()->insert_stmt_before(init, forJ); - ASSERT_THROWS_WITH( - LoopNest::reorder({forI, forJ, forK}, {1, 0, 2}), - "reorder is only allowed on perfectly nested loops"); -} - -TEST(LoopNest, compressBufferSimple) { - // Input IR: - // for (int i = 0; i < 100; ++i) { - // for (int j = 0; j < 200; ++j) { - // A[i,j] = sin(i*j) - // } - // for (int j = 0; j < 199; ++j) { - // B[i,j] = A[i,j] + A[i, j+1] - // } - // } - BufHandle aBuf("A", {100, 200}, kInt); - BufHandle bBuf("B", {100, 200}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - auto forJ1 = For::make(j, 0, 200, Store::make(aBuf, {i, j}, sin(i * j))); - auto forJ2 = For::make( - j, - 0, - 199, - Store::make( - bBuf, - {i, j}, - Add::make(Load::make(aBuf, {i, j}), Load::make(aBuf, {i, j + 1})))); - auto forI = For::make(i, 0, 100, Block::make({forJ1, forJ2})); - auto par = Block::make({forI}); - LoopNest::compressBuffer(aBuf.node(), par); - - std::ostringstream oss; - oss << *par; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NEXT: for (int j -# CHECK-NEXT: A[0, j] = -# CHECK: for (int j -# CHECK-NEXT: B[i, j] = (A[0, j]) + (A[0, j + 1]) - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - ASSERT_EQ(aBuf.node()->ndim(), 2); - IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 1); - IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 200); -} - -TEST(LoopNest, compressBufferMultipleDims) { - // Input IR: - // for (int i = 0; i < 100; ++i) { - // for (int j = 0; j < 200; ++j) { - // A[i,j] = sin(i*j) - // B[i,j] = A[i,j] + A[i,j] - // } - // } - BufHandle aBuf("A", {100, 200}, kInt); - BufHandle bBuf("B", {100, 200}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - auto store1 = Store::make(aBuf, {i, j}, sin(i * j)); - auto store2 = Store::make( - bBuf, - {i, j}, - Add::make(Load::make(aBuf, {i, j}), Load::make(aBuf, {i, j}))); - auto forJ = For::make(j, 0, 200, Block::make({store1, store2})); - auto forI = For::make(i, 0, 100, forJ); - auto par = Block::make({forI}); - LoopNest::compressBuffer(aBuf.node(), par); - - std::ostringstream oss; - oss << *par; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NEXT: for (int j -# CHECK-NEXT: A[0, 0] = -# CHECK-NEXT: B[i, j] = (A[0, 0]) + (A[0, 0]) - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - ASSERT_EQ(aBuf.node()->ndim(), 2); - IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 1); - IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 1); -} - -TEST(LoopNest, compressBufferMultipleDims2) { - // Input IR: - // for (int i = 0; i < 100; ++i) { - // for (int j = 0; j < 200; ++j) { - // for (int k = 0; k < 300; ++k) { - // A[i,j,k] = sin(i*j*k) - // } - // for (int k = 0; k < 299; ++k) { - // B[i,j,k] = A[i,j,k] + A[i,j,k+1] - // } - // } - // } - BufHandle aBuf("A", {100, 200, 300}, kInt); - BufHandle bBuf("B", {100, 200, 300}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - auto store1 = Store::make(aBuf, {i, j, k}, sin(i * j * k)); - auto forK1 = For::make(k, 0, 300, store1); - auto store2 = Store::make( - bBuf, - {i, j, k}, - Add::make(Load::make(aBuf, {i, j, k}), Load::make(aBuf, {i, j, k + 1}))); - auto forK2 = For::make(k, 0, 299, store2); - auto forJ = For::make(j, 0, 200, Block::make({forK1, forK2})); - auto forI = For::make(i, 0, 100, forJ); - auto par = Block::make({forI}); - LoopNest::compressBuffer(aBuf.node(), par); - - std::ostringstream oss; - oss << *par; - const std::string& 
verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NEXT: for (int j -# CHECK-NEXT: for (int k -# CHECK-NEXT: A[0, 0, k] = -# CHECK: for (int k -# CHECK-NEXT: B[i, j, k] = (A[0, 0, k]) + (A[0, 0, k + 1]) - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - ASSERT_EQ(aBuf.node()->ndim(), 3); - IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 1); - IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 1); - IS_IMM_WITH_VAL(Int, aBuf.node()->dim(2), 300); -} - -TEST(LoopNest, compressBufferDifferentOrderIndices) { - // Input IR: - // for (int i = 0; i < 100; ++i) { - // for (int j = 0; j < 200; ++j) { - // A[j, i] = sin(i*j) - // } - // for (int j = 0; j < 99; ++j) { - // B[i, j] = A[j, i] + A[j+1, i] - // } - // } - BufHandle aBuf("A", {100, 200}, kInt); - BufHandle bBuf("B", {100, 200}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - auto forJ1 = For::make(j, 0, 200, Store::make(aBuf, {j, i}, sin(i * j))); - auto forJ2 = For::make( - j, - 0, - 99, - Store::make( - bBuf, - {i, j}, - Add::make(Load::make(aBuf, {j, i}), Load::make(aBuf, {j + 1, i})))); - auto forI = For::make(i, 0, 100, Block::make({forJ1, forJ2})); - auto par = Block::make({forI}); - LoopNest::compressBuffer(aBuf.node(), par); - - std::ostringstream oss; - oss << *par; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NEXT: for (int j -# CHECK-NEXT: A[j, 0] = -# CHECK: for (int j -# CHECK-NEXT: B[i, j] = (A[j, 0]) + (A[j + 1, 0]) - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - ASSERT_EQ(aBuf.node()->ndim(), 2); - IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 100); - IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 1); -} - -TEST(LoopNest, compressBufferVariableBounds) { - // Input IR: - // for (int i = 0; i < M; ++i) { - // for (int j = 0; j < N; ++j) { - // A[i,j] = sin(i*j) - // } - // for (int j = 0; j < N-1; ++j) { - // B[i,j] = A[i,j] + A[i, j+1] - // } - // } - BufHandle aBuf("A", {100, 200}, kInt); - BufHandle bBuf("B", {100, 200}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle M("M", kInt); - VarHandle N("N", kInt); - auto forJ1 = For::make(j, 0, N, Store::make(aBuf, {i, j}, sin(i * j))); - auto forJ2 = For::make( - j, - 0, - N - 1, - Store::make( - bBuf, - {i, j}, - Add::make(Load::make(aBuf, {i, j}), Load::make(aBuf, {i, j + 1})))); - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - auto forI = For::make(i, 0, M, Block::make({forJ1, forJ2})); - auto par = Block::make({forI}); - LoopNest::compressBuffer(aBuf.node(), par); - - std::ostringstream oss; - oss << *par; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NEXT: for (int j -# CHECK-NEXT: A[0, j] = -# CHECK: for (int j -# CHECK-NEXT: B[i, j] = (A[0, j]) + (A[0, j + 1]) - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - ASSERT_EQ(aBuf.node()->ndim(), 2); - IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 1); - IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 200); -} - -TEST(LoopNest, compressBufferNoCommonParentLoops) { - // Input IR: - // for (int i = 0; i < 100; ++i) { - // for (int j = 0; j < 200; ++j) { - // A[i,j] = sin(i*j) - // } - // } - // for (int i = 0; i < 100; ++i) { - // for (int j = 0; j < 199; ++j) { - // B[i,j] = A[i,j] + A[i, j+1] - // } - // } - BufHandle aBuf("A", {100, 200}, kInt); - BufHandle bBuf("B", {100, 200}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - auto forJ1 = For::make(j, 0, 200, Store::make(aBuf, {i, j}, sin(i * j))); - auto forJ2 = 
For::make( - j, - 0, - 199, - Store::make( - bBuf, - {i, j}, - Add::make(Load::make(aBuf, {i, j}), Load::make(aBuf, {i, j + 1})))); - auto forI1 = For::make(i, 0, 100, forJ1); - auto forI2 = For::make(i, 0, 100, forJ2); - auto par = Block::make({forI1, forI2}); - LoopNest::compressBuffer(aBuf.node(), par); - - // There should be no change in the buffer or code. - std::ostringstream oss; - oss << *par; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NEXT: for (int j -# CHECK-NEXT: A[i, j] = -# CHECK: for (int i -# CHECK-NEXT: for (int j -# CHECK-NEXT: B[i, j] = (A[i, j]) + (A[i, j + 1]) - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - ASSERT_EQ(aBuf.node()->ndim(), 2); - IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 100); - IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 200); -} - -TEST(LoopNest, compressBufferIndicesMixed) { - // Input IR: - // for (int i = 0; i < 100; ++i) { - // for (int j = 0; j < 200; ++j) { - // A[i + j, j] = sin(i*j) - // } - // for (int j = 0; j < 199; ++j) { - // B[i,j] = A[i + j, j] + A[i + j, j+1] - // } - // } - BufHandle aBuf("A", {300, 200}, kInt); - BufHandle bBuf("B", {100, 200}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - auto forJ1 = For::make(j, 0, 200, Store::make(aBuf, {i + j, j}, sin(i * j))); - auto forJ2 = For::make( - j, - 0, - 199, - Store::make( - bBuf, - {i, j}, - Add::make( - Load::make(aBuf, {i + j, j}), Load::make(aBuf, {i + j, j + 1})))); - auto forI = For::make(i, 0, 100, Block::make({forJ1, forJ2})); - auto par = Block::make({forI}); - LoopNest::compressBuffer(aBuf.node(), par); - - // There should be no change in the buffer or code. - std::ostringstream oss; - oss << *par; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NEXT: for (int j -# CHECK-NEXT: A[i + j, j] = -# CHECK: for (int j -# CHECK-NEXT: B[i, j] = (A[i + j, j]) + (A[i + j, j + 1]) - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - ASSERT_EQ(aBuf.node()->ndim(), 2); - IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 300); - IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 200); -} - -TEST(LoopNest, compressMultipleBuffers) { - // Input IR: - // for (int i = 0; i < 100; ++i) { - // for (int j = 0; j < 200; ++j) { - // A[i,j] = sin(i*j) - // } - // for (int k = 0; k < 199; ++k) { - // B[i,k] = A[i,k] + A[i, k+1] - // } - // for (int m = 0; m < 50; ++m) { - // C[i,m] = B[i,m] - // } - // } - BufHandle aBuf("A", {100, 200}, kInt); - BufHandle bBuf("B", {100, 200}, kInt); - BufHandle cBuf("C", {100, 200}, kInt); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - VarHandle k("k", kInt); - VarHandle m("m", kInt); - auto forJ = For::make(j, 0, 200, Store::make(aBuf, {i, j}, sin(i * j))); - auto forK = For::make( - k, - 0, - 199, - Store::make( - bBuf, - {i, k}, - Add::make(Load::make(aBuf, {i, k}), Load::make(aBuf, {i, k + 1})))); - auto forM = - For::make(m, 0, 50, Store::make(cBuf, {i, m}, Load::make(bBuf, {i, m}))); - auto forI = For::make(i, 0, 100, Block::make({forJ, forK, forM})); - auto par = Block::make({forI}); - - // This should compress all buffers A, B, and C as follows: - // A[100, 200] -> A[1, 200] - // B[100, 200] -> B[1, 200] - // C[100, 200] -> C[1, 1] - LoopNest::compressAllBuffers(par); - - std::ostringstream oss; - oss << *par; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NEXT: for (int j -# CHECK-NEXT: A[0, j] = -# CHECK: for (int k -# CHECK-NEXT: B[0, k] = (A[0, k]) + (A[0, k + 1]) -# CHECK: for 
(int m -# CHECK-NEXT: C[0, 0] = B[0, m] - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); - - ASSERT_EQ(aBuf.node()->ndim(), 2); - IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 1); - IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 200); - ASSERT_EQ(bBuf.node()->ndim(), 2); - IS_IMM_WITH_VAL(Int, bBuf.node()->dim(0), 1); - IS_IMM_WITH_VAL(Int, bBuf.node()->dim(1), 200); - ASSERT_EQ(cBuf.node()->ndim(), 2); - IS_IMM_WITH_VAL(Int, cBuf.node()->dim(0), 1); - IS_IMM_WITH_VAL(Int, cBuf.node()->dim(1), 1); -} - -TEST(LoopNest, sanitizeNames) { - std::vector dim_args; - // Let's pick names that would overlap with default index names if not - // sanitized properly: - dim_args.emplace_back(ExprHandle(alloc("i", kInt))); - dim_args.emplace_back(ExprHandle(alloc("N:2", kInt))); - // Now let's create many dimensions so that we have to use the same letter - // for different loops - for (int i = 0; i < 10; i++) { - dim_args.emplace_back(ExprHandle(alloc("N", kInt))); - } - - // Now create two Computes whose names conflict after sanitization: - Tensor X = Compute("$X:!", dim_args, [&](const std::vector& v) { - return v[0] + v[1] + v[9] + 1; - }); - Tensor Y = Reduce( - "%X\"+", - {}, - Sum(), - [&](const std::vector& v) { return X.load(v); }, - dim_args); - - // Finally, let's verify what we got after sanitization: - LoopNest l({X, Y}); - StmtPtr s = l.root_stmt(); - LoopNest::sanitizeNames(s); - - std::ostringstream oss; - oss << *s; - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i = 0; i < i_1; i++) { -# CHECK-NEXT: for (int j = 0; j < N_2_1; j++) { -# CHECK-NEXT: for (int k = 0; k < N_9; k++) { -# CHECK-NEXT: for (int l = 0; l < N_8; l++) { -# CHECK-NEXT: for (int m = 0; m < N_7; m++) { -# CHECK-NEXT: for (int n = 0; n < N_6; n++) { -# CHECK-NEXT: for (int o = 0; o < N_5; o++) { -# CHECK-NEXT: for (int p = 0; p < N_4; p++) { -# CHECK-NEXT: for (int i1 = 0; i1 < N_3; i1++) { -# CHECK-NEXT: for (int j1 = 0; j1 < N_2; j1++) { -# CHECK-NEXT: for (int k1 = 0; k1 < N_1; k1++) { -# CHECK-NEXT: for (int l1 = 0; l1 < N; l1++) { -# CHECK-NEXT: v_X__[i, j, k, l, m, n, o, p, i1, j1, k1, l1] = ((i + j) + j1) + 1; -# CHECK: v_X___1 = int(0); -# CHECK-NEXT: for (int i_2 = 0; i_2 < i_1; i_2++) { -# CHECK-NEXT: for (int j_1 = 0; j_1 < N_2_1; j_1++) { -# CHECK-NEXT: for (int k_1 = 0; k_1 < N_9; k_1++) { -# CHECK-NEXT: for (int l_1 = 0; l_1 < N_8; l_1++) { -# CHECK-NEXT: for (int m_1 = 0; m_1 < N_7; m_1++) { -# CHECK-NEXT: for (int n_1 = 0; n_1 < N_6; n_1++) { -# CHECK-NEXT: for (int o_1 = 0; o_1 < N_5; o_1++) { -# CHECK-NEXT: for (int p_1 = 0; p_1 < N_4; p_1++) { -# CHECK-NEXT: for (int i1_1 = 0; i1_1 < N_3; i1_1++) { -# CHECK-NEXT: for (int j1_1 = 0; j1_1 < N_2; j1_1++) { -# CHECK-NEXT: for (int k1_1 = 0; k1_1 < N_1; k1_1++) { -# CHECK-NEXT: for (int l1_1 = 0; l1_1 < N; l1_1++) { -# CHECK-NEXT: v_X___1 = ReduceOp((v_X___1) + (v_X__[i_2, j_1, k_1, l_1, m_1, n_1, o_1, p_1, i1_1, j1_1, k1_1, l1_1]), reduce_args={i_2, j_1, k_1, l_1, m_1, n_1, o_1, p_1, i1_1, j1_1, k1_1, l1_1}); - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -} // namespace jit -} // namespace torch diff --git a/test/cpp/tensorexpr/test_memdependency.cpp b/test/cpp/tensorexpr/test_memdependency.cpp deleted file mode 100644 index 5db84eab1f50..000000000000 --- a/test/cpp/tensorexpr/test_memdependency.cpp +++ /dev/null @@ -1,3252 +0,0 @@ -#include -#include - -#include -#include -#include -#include -#include -#include -#include - -namespace torch { -namespace jit { - -using 
namespace torch::jit::tensorexpr; - -// Test helper function used to determine if two regions of a buffer have an -// overlap. No Overlap & partial overlap are obvious. Contains means A is -// larger and fully encloses B, while ContainedOrEqual is the reverse. Equal -// ranges are ContainedOrEqual. -TEST(MemDependency, BoundOverlap) { - using namespace analysis; - - auto CB = [](int s, int e) { - return Bound(alloc(s), alloc(e)); - }; - - // Sanity check 3 overlap cases. - ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(0, 0), CB(0, 0))); - ASSERT_EQ(OverlapKind::PartialOverlap, boundOverlap(CB(0, 3), CB(2, 5))); - ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(0, 0), CB(1, 1))); - - // Partial overlap works in either order. - ASSERT_EQ(OverlapKind::PartialOverlap, boundOverlap(CB(0, 10), CB(7, 14))); - ASSERT_EQ(OverlapKind::PartialOverlap, boundOverlap(CB(7, 14), CB(0, 10))); - - // Total Overlap works when one bound encloses the other, and returns which. - ASSERT_EQ(OverlapKind::Contains, boundOverlap(CB(2, 15), CB(7, 9))); - ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(2, 15), CB(0, 16))); - - // Total overlap works when the bounds are an identical range, and returns - // ContainedOrEqual. - ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(2, 15), CB(2, 15))); - - // Total overlap when only one end of the bound matches. - ASSERT_EQ(OverlapKind::Contains, boundOverlap(CB(2, 15), CB(2, 10))); - ASSERT_EQ(OverlapKind::Contains, boundOverlap(CB(2, 15), CB(3, 15))); - ASSERT_EQ(OverlapKind::Contains, boundOverlap(CB(0, 10), CB(0, 9))); - ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(2, 10), CB(2, 15))); - ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(3, 15), CB(2, 15))); - - // No overlap when a < b. - ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(0, 2), CB(5, 10))); - ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(2, 2), CB(3, 3))); - ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(100, 120), CB(130, 130))); - - // No overlap when a > b. - ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(5, 10), CB(0, 2))); - ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(3, 3), CB(2, 2))); - ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(130, 130), CB(100, 120))); - - // No overlap when adjacent. - ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(0, 100), CB(101, 120))); - ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(2, 3), CB(0, 1))); - - // Partial overlap when middle bounds match. - ASSERT_EQ( - OverlapKind::PartialOverlap, boundOverlap(CB(0, 100), CB(100, 120))); - ASSERT_EQ(OverlapKind::PartialOverlap, boundOverlap(CB(0, 2), CB(2, 4))); - ASSERT_EQ( - OverlapKind::PartialOverlap, boundOverlap(CB(100, 120), CB(0, 100))); - ASSERT_EQ(OverlapKind::PartialOverlap, boundOverlap(CB(2, 3), CB(1, 2))); - - // Total overlap when one bound is single length over one end of the other. 
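[Editor's aside] The overlap semantics this test asserts can be restated compactly. Below is a standalone, hedged sketch (Kind and classifyOverlap are illustrative names, not the TE API), assuming closed integer bounds; the single-length cases the comment above introduces continue immediately after this aside.

namespace overlap_sketch {

enum class Kind { NoOverlap, PartialOverlap, Contains, ContainedOrEqual };

// Closed bounds [aS, aE] vs [bS, bE]. Equal ranges are ContainedOrEqual;
// Contains means A fully encloses B, even when one endpoint matches;
// adjacent ranges do not overlap.
inline Kind classifyOverlap(int aS, int aE, int bS, int bE) {
  if (aE < bS || bE < aS) {
    return Kind::NoOverlap;
  }
  if (aS <= bS && bE <= aE) {
    return (aS == bS && aE == bE) ? Kind::ContainedOrEqual : Kind::Contains;
  }
  if (bS <= aS && aE <= bE) {
    return Kind::ContainedOrEqual;
  }
  return Kind::PartialOverlap;
}

} // namespace overlap_sketch

// e.g. classifyOverlap(0, 3, 2, 5) == PartialOverlap,
//      classifyOverlap(2, 15, 7, 9) == Contains, and
//      classifyOverlap(0, 100, 101, 120) == NoOverlap, as asserted above.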
- ASSERT_EQ(OverlapKind::Contains, boundOverlap(CB(2, 15), CB(15, 15))); - ASSERT_EQ(OverlapKind::Contains, boundOverlap(CB(2, 15), CB(2, 2))); - ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(2, 2), CB(2, 15))); - ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(15, 15), CB(2, 15))); -} - -TEST(MemDependency, BoundComparison) { - using namespace analysis; - - auto CB = [](int s, int e) { - return Bound(alloc(s), alloc(e)); - }; - - ASSERT_EQ( - CmpEvalResult::NotDetermined, - compareBound(CB(10, 100), CB(10, 100), CompareSelectOperation::kEQ)); - ASSERT_EQ( - CmpEvalResult::True, - compareBound(CB(10, 10), CB(10, 10), CompareSelectOperation::kEQ)); - ASSERT_EQ( - CmpEvalResult::False, - compareBound(CB(10, 20), CB(30, 40), CompareSelectOperation::kEQ)); - ASSERT_EQ( - CmpEvalResult::False, - compareBound(CB(30, 40), CB(10, 20), CompareSelectOperation::kEQ)); - ASSERT_EQ( - CmpEvalResult::NotDetermined, - compareBound(CB(30, 40), CB(40, 50), CompareSelectOperation::kEQ)); - ASSERT_EQ( - CmpEvalResult::NotDetermined, - compareBound(CB(30, 40), CB(20, 30), CompareSelectOperation::kEQ)); - ASSERT_EQ( - CmpEvalResult::NotDetermined, - compareBound(CB(30, 45), CB(40, 50), CompareSelectOperation::kEQ)); - - ASSERT_EQ( - CmpEvalResult::NotDetermined, - compareBound(CB(10, 100), CB(10, 100), CompareSelectOperation::kNE)); - ASSERT_EQ( - CmpEvalResult::False, - compareBound(CB(10, 10), CB(10, 10), CompareSelectOperation::kNE)); - ASSERT_EQ( - CmpEvalResult::True, - compareBound(CB(10, 20), CB(30, 40), CompareSelectOperation::kNE)); - ASSERT_EQ( - CmpEvalResult::True, - compareBound(CB(30, 40), CB(10, 20), CompareSelectOperation::kNE)); - ASSERT_EQ( - CmpEvalResult::NotDetermined, - compareBound(CB(30, 40), CB(40, 50), CompareSelectOperation::kNE)); - ASSERT_EQ( - CmpEvalResult::NotDetermined, - compareBound(CB(30, 40), CB(20, 30), CompareSelectOperation::kEQ)); - ASSERT_EQ( - CmpEvalResult::NotDetermined, - compareBound(CB(30, 45), CB(40, 50), CompareSelectOperation::kNE)); - - ASSERT_EQ( - CmpEvalResult::True, - compareBound(CB(10, 20), CB(30, 40), CompareSelectOperation::kLT)); - ASSERT_EQ( - CmpEvalResult::False, - compareBound(CB(30, 40), CB(10, 20), CompareSelectOperation::kLT)); - ASSERT_EQ( - CmpEvalResult::False, - compareBound(CB(30, 40), CB(10, 30), CompareSelectOperation::kLT)); - ASSERT_EQ( - CmpEvalResult::NotDetermined, - compareBound(CB(10, 100), CB(10, 100), CompareSelectOperation::kLT)); - ASSERT_EQ( - CmpEvalResult::NotDetermined, - compareBound(CB(30, 40), CB(40, 50), CompareSelectOperation::kLT)); - ASSERT_EQ( - CmpEvalResult::NotDetermined, - compareBound(CB(30, 45), CB(40, 50), CompareSelectOperation::kLT)); - - ASSERT_EQ( - CmpEvalResult::False, - compareBound(CB(10, 20), CB(30, 40), CompareSelectOperation::kGE)); - ASSERT_EQ( - CmpEvalResult::True, - compareBound(CB(30, 40), CB(10, 20), CompareSelectOperation::kGE)); - ASSERT_EQ( - CmpEvalResult::True, - compareBound(CB(30, 40), CB(10, 30), CompareSelectOperation::kGE)); - ASSERT_EQ( - CmpEvalResult::NotDetermined, - compareBound(CB(10, 100), CB(10, 100), CompareSelectOperation::kGE)); - ASSERT_EQ( - CmpEvalResult::NotDetermined, - compareBound(CB(30, 40), CB(40, 50), CompareSelectOperation::kGE)); - ASSERT_EQ( - CmpEvalResult::NotDetermined, - compareBound(CB(30, 45), CB(40, 50), CompareSelectOperation::kGE)); - - ASSERT_EQ( - CmpEvalResult::False, - compareBound(CB(10, 20), CB(30, 40), CompareSelectOperation::kGT)); - ASSERT_EQ( - CmpEvalResult::False, - compareBound(CB(30, 40), CB(40, 50), 
CompareSelectOperation::kGT)); - ASSERT_EQ( - CmpEvalResult::NotDetermined, - compareBound(CB(10, 100), CB(10, 100), CompareSelectOperation::kGT)); - ASSERT_EQ( - CmpEvalResult::True, - compareBound(CB(30, 40), CB(10, 20), CompareSelectOperation::kGT)); - ASSERT_EQ( - CmpEvalResult::NotDetermined, - compareBound(CB(30, 40), CB(10, 30), CompareSelectOperation::kGT)); - ASSERT_EQ( - CmpEvalResult::NotDetermined, - compareBound(CB(30, 45), CB(40, 50), CompareSelectOperation::kGT)); - - ASSERT_EQ( - CmpEvalResult::True, - compareBound(CB(10, 20), CB(30, 40), CompareSelectOperation::kLE)); - ASSERT_EQ( - CmpEvalResult::True, - compareBound(CB(30, 40), CB(40, 50), CompareSelectOperation::kLE)); - ASSERT_EQ( - CmpEvalResult::NotDetermined, - compareBound(CB(10, 100), CB(10, 100), CompareSelectOperation::kLE)); - ASSERT_EQ( - CmpEvalResult::False, - compareBound(CB(30, 40), CB(10, 20), CompareSelectOperation::kLE)); - ASSERT_EQ( - CmpEvalResult::NotDetermined, - compareBound(CB(30, 40), CB(10, 30), CompareSelectOperation::kLE)); - ASSERT_EQ( - CmpEvalResult::NotDetermined, - compareBound(CB(30, 45), CB(40, 50), CompareSelectOperation::kLE)); -} - -TEST(MemDependency, BoundOverlapSymbolic) { - VarHandle x("x", kInt); - VarHandle y("y", kInt); - VarHandle z("z", kInt); - VarHandle w("w", kInt); - - using namespace analysis; - - auto CB = [](ExprHandle s, ExprHandle e) { - return Bound(s.node(), e.node()); - }; - - // Sanity check cases where the start and end are symbolic but the diff is - // constant. - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(x, x), CB(x, x))); - ASSERT_EQ( - OverlapKind::PartialOverlap, - boundOverlap(CB(x, x + 3), CB(x + 2, x + 5))); - ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(x, x), CB(x + 1, x + 1))); - - // We can't infer the sign of y, so cannot tell whether adding y is larger or - // smaller than y/2. - ASSERT_EQ( - OverlapKind::PartialOverlap, - boundOverlap(CB(x, x + y), CB(x, x + y / 2))); - - // No information about this bound, so we have to take the most conservative - // option: there may be an overlap. - ASSERT_EQ(OverlapKind::PartialOverlap, boundOverlap(CB(x, y), CB(z, w))); - - // Math on opaque terms works. - ASSERT_EQ( - OverlapKind::ContainedOrEqual, - boundOverlap(CB(x + w, y - z), CB(x + w, y - z))); - // Even requiring simplification. - ASSERT_EQ( - OverlapKind::ContainedOrEqual, - boundOverlap(CB(x - w - w, y), CB(x - w * 2, y))); -} - -// Tests the helper function for overlap of multi dimensional indices bounds. -// This uses boundOverlap on each dimension and returns the "lowest" kind of -// overlap. -TEST(MemDependency, BoundOverlapMultiDim) { - using namespace analysis; - - auto CB = [](int s, int e) { - return Bound(alloc(s), alloc(e)); - }; - - // Sanity check one dimensional cases. - ASSERT_EQ(OverlapKind::ContainedOrEqual, overlaps({CB(0, 0)}, {CB(0, 0)})); - ASSERT_EQ(OverlapKind::NoOverlap, overlaps({CB(0, 2)}, {CB(5, 10)})); - ASSERT_EQ( - OverlapKind::PartialOverlap, overlaps({CB(0, 100)}, {CB(100, 120)})); - - // Total overlap in 3 dims. - ASSERT_EQ( - OverlapKind::ContainedOrEqual, - overlaps({CB(0, 2), CB(0, 5), CB(0, 4)}, {CB(0, 2), CB(0, 5), CB(0, 4)})); - ASSERT_EQ( - OverlapKind::ContainedOrEqual, - overlaps( - {CB(0, 2), CB(0, 5), CB(0, 4)}, {CB(0, 2), CB(0, 5), CB(0, 10)})); - - // Total overlap in 2 dims, no overlap in another. 
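[Editor's aside] The "lowest kind" rule the comment above describes can be sketched as a fold over per-dimension results. A hedged, standalone sketch (hypothetical names; the Kind values mirror the earlier aside and are redeclared here so the snippet stays self-contained): NoOverlap in any dimension makes the whole access disjoint, any PartialOverlap then dominates, and Contains survives only if no dimension is partial.

#include <vector>

namespace combine_sketch {

enum class Kind { NoOverlap, PartialOverlap, Contains, ContainedOrEqual };

inline Kind combineDims(const std::vector<Kind>& dims) {
  bool partial = false;
  bool contains = false;
  for (Kind k : dims) {
    if (k == Kind::NoOverlap) {
      return Kind::NoOverlap; // disjoint in one dim means disjoint overall
    }
    partial = partial || (k == Kind::PartialOverlap);
    contains = contains || (k == Kind::Contains);
  }
  if (partial) {
    return Kind::PartialOverlap;
  }
  return contains ? Kind::Contains : Kind::ContainedOrEqual;
}

} // namespace combine_sketch

// This reproduces e.g. {ContainedOrEqual, ContainedOrEqual, NoOverlap} ->
// NoOverlap in the assertion just below, and the Contains results further on.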
- ASSERT_EQ( - OverlapKind::NoOverlap, - overlaps( - {CB(0, 2), CB(0, 5), CB(0, 4)}, {CB(0, 2), CB(0, 5), CB(5, 10)})); - - // Total overlap in 2 dims, partial overlap in another. - ASSERT_EQ( - OverlapKind::PartialOverlap, - overlaps( - {CB(0, 2), CB(0, 5), CB(0, 5)}, {CB(0, 2), CB(0, 5), CB(5, 10)})); - // This case is most important, so verify the overlap in any dim. (dim 2) - ASSERT_EQ( - OverlapKind::PartialOverlap, - overlaps({CB(0, 2), CB(0, 5), CB(0, 5)}, {CB(0, 2), CB(2, 6), CB(0, 5)})); - // Dim 1. - ASSERT_EQ( - OverlapKind::PartialOverlap, - overlaps({CB(0, 2), CB(0, 5), CB(0, 5)}, {CB(1, 3), CB(0, 5), CB(0, 5)})); - // Total overlap in 1 dim, partial in 2. - ASSERT_EQ( - OverlapKind::PartialOverlap, - overlaps( - {CB(0, 2), CB(0, 5), CB(0, 5)}, {CB(2, 6), CB(0, 5), CB(5, 10)})); - // Total overlap, partial overlap, no overlap. - ASSERT_EQ( - OverlapKind::NoOverlap, - overlaps( - {CB(0, 2), CB(0, 5), CB(0, 5)}, {CB(2, 6), CB(11, 15), CB(0, 5)})); - - // Total overlap (B) in 2 dims, total overlap (A) in another. - ASSERT_EQ( - OverlapKind::Contains, - overlaps({CB(0, 2), CB(0, 5), CB(0, 4)}, {CB(0, 2), CB(0, 3), CB(0, 4)})); - - // Total overlap (A) in 2 dims, total overlap (B) in another. - ASSERT_EQ( - OverlapKind::Contains, - overlaps( - {CB(0, 12), CB(0, 15), CB(0, 4)}, {CB(0, 2), CB(0, 3), CB(0, 14)})); - - // Total (B), No Overlap, Total (A). - ASSERT_EQ( - OverlapKind::NoOverlap, - overlaps( - {CB(0, 2), CB(0, 5), CB(0, 5)}, {CB(0, 6), CB(11, 15), CB(1, 2)})); -} - -// Test the helper we use to subtract bounds: returns the region(s) of A which -// remain after removing the region of B. -TEST(MemDependency, BoundSubtract) { - using namespace analysis; - - auto CB = [](int s, int e) { - return Bound(alloc(s), alloc(e)); - }; - auto EQ = [](const IndexBounds& x, const IndexBounds& y) { - return indexBoundsEquals(x, y); - }; - - // One element subtract. - ASSERT_EQ(subtractBound(CB(0, 0), CB(0, 0)).size(), 0); - ASSERT_EQ(subtractBound(CB(5, 5), CB(5, 5)).size(), 0); - - // No Overlap. - ASSERT_TRUE(EQ(subtractBound(CB(5, 5), CB(2, 2)), {CB(5, 5)})); - ASSERT_TRUE(EQ(subtractBound(CB(5, 5), CB(0, 4)), {CB(5, 5)})); - - // one side overlap. - ASSERT_TRUE(EQ(subtractBound(CB(1, 5), CB(4, 7)), {CB(1, 3)})); - ASSERT_TRUE(EQ(subtractBound(CB(0, 5), CB(5, 7)), {CB(0, 4)})); - ASSERT_TRUE(EQ(subtractBound(CB(4, 5), CB(1, 4)), {CB(5, 5)})); - ASSERT_TRUE(EQ(subtractBound(CB(1, 5), CB(0, 4)), {CB(5, 5)})); - - // both sides overlap. - ASSERT_TRUE(EQ(subtractBound(CB(1, 5), CB(0, 7)), {})); - ASSERT_TRUE(EQ(subtractBound(CB(5, 5), CB(5, 7)), {})); - - // internal overlap. - ASSERT_TRUE(EQ(subtractBound(CB(1, 5), CB(2, 3)), {CB(1, 1), CB(4, 5)})); - ASSERT_TRUE(EQ(subtractBound(CB(0, 5), CB(2, 4)), {CB(0, 1), CB(5, 5)})); -} - -TEST(MemDependency, BoundSubtractSymbolic) { - VarHandle x("x", kInt); - VarHandle y("y", kInt); - VarHandle z("z", kInt); - VarHandle w("w", kInt); - - using namespace analysis; - - auto CB = [](ExprHandle s, ExprHandle e) { - return Bound(s.node(), e.node()); - }; - auto EQ = [](const IndexBounds& x, const IndexBounds& y) { - return indexBoundsEquals(x, y); - }; - - // One element subtract. - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - ASSERT_TRUE(EQ(subtractBound(CB(x, x), CB(x, x)), {})); - ASSERT_TRUE(EQ(subtractBound(CB(x + 1, x + 1), CB(x + 1, x + 1)), {})); - ASSERT_TRUE(EQ(subtractBound(CB(x * 2, x * 2), CB(x * 2, x * 2)), {})); - - // Subtract constant range low. 
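[Editor's aside] The integer cases in BoundSubtract above pin down an interval subtraction that yields zero, one, or two pieces. A minimal standalone sketch (hypothetical names, closed bounds):

#include <utility>
#include <vector>

namespace subtract_sketch {

using Range = std::pair<int, int>; // closed bounds [start, end]

// Pieces of A left after removing B: none when B covers A, one piece when
// B clips an end of A, two pieces when B sits strictly inside A.
inline std::vector<Range> subtractRange(Range a, Range b) {
  if (b.second < a.first || a.second < b.first) {
    return {a}; // no overlap: A survives whole
  }
  std::vector<Range> out;
  if (a.first < b.first) {
    out.push_back({a.first, b.first - 1}); // piece below B
  }
  if (b.second < a.second) {
    out.push_back({b.second + 1, a.second}); // piece above B
  }
  return out;
}

} // namespace subtract_sketch

// Mirrors subtractBound(CB(1, 5), CB(2, 3)) == {CB(1, 1), CB(4, 5)} and
// subtractBound(CB(1, 5), CB(4, 7)) == {CB(1, 3)} from the assertions above.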
- ASSERT_TRUE( - EQ(subtractBound(CB(x, x + 10), CB(x, x + 4)), {CB(x + 5, x + 10)})); - // Subtract constant range high. - ASSERT_TRUE( - EQ(subtractBound(CB(x, x + 10), CB(x + 6, x + 12)), {CB(x, x + 5)})); - // Subtract constant range total overlap. - ASSERT_TRUE(EQ(subtractBound(CB(x, x + 10), CB(x, x + 10)), {})); - ASSERT_TRUE(EQ(subtractBound(CB(x + 2, x + 10), CB(x, x + 12)), {})); - // Subtract constant range internal. - ASSERT_TRUE( - EQ(subtractBound(CB(x, x + 10), CB(x + 3, x + 7)), - {CB(x, x + 2), CB(x + 8, x + 10)})); - - // Size is inferable but not constant, only works with a single var. - ASSERT_TRUE(EQ(subtractBound(CB(0, x), CB(0, x * 2)), {})); - ASSERT_TRUE(EQ(subtractBound(CB(0, x * 2), CB(0, x - 1)), {CB(x, x * 2)})); - - // Size is not inferable. - ASSERT_TRUE(EQ(subtractBound(CB(x, y), CB(z, w)), {CB(x, y)})); - ASSERT_TRUE(EQ(subtractBound(CB(x, y), CB(x, z)), {CB(x, y)})); - ASSERT_TRUE(EQ(subtractBound(CB(x, y), CB(0, x)), {CB(x, y)})); - ASSERT_TRUE(EQ(subtractBound(CB(x, x), CB(0, 0)), {CB(x, x)})); -} - -// Tests the helper function that does subtraction, but for multi dimensional -// indices bounds. -TEST(MemDependency, BoundSubtractMultiDim) { - using namespace analysis; - - auto CB = [](int s, int e) { - return Bound(alloc(s), alloc(e)); - }; - auto EQ = [](std::vector x, std::vector y) { - if (x.size() != y.size()) { - return false; - } - for (auto i = 0U; i < x.size(); ++i) { - if (!indexBoundsEquals(x[i], y[i])) { - return false; - } - } - return true; - }; - - // sanity check one dimension. - ASSERT_TRUE(EQ(subtractIndicesBounds({CB(0, 9)}, {CB(0, 9)}), {})); - ASSERT_TRUE(EQ(subtractIndicesBounds({CB(3, 9)}, {CB(0, 12)}), {})); - ASSERT_TRUE( - EQ(subtractIndicesBounds({CB(0, 12)}, {CB(0, 9)}), {{CB(10, 12)}})); - ASSERT_TRUE( - EQ(subtractIndicesBounds({CB(0, 12)}, {CB(3, 12)}), {{CB(0, 2)}})); - ASSERT_TRUE(EQ( - subtractIndicesBounds({CB(0, 9)}, {CB(1, 8)}), {{CB(0, 0)}, {CB(9, 9)}})); - - // Multi dim total overlap. - ASSERT_TRUE(EQ( - subtractIndicesBounds({CB(0, 9), CB(0, 2)}, {CB(0, 9), CB(0, 2)}), {})); - ASSERT_TRUE(EQ( - subtractIndicesBounds({CB(0, 9), CB(0, 2)}, {CB(0, 10), CB(0, 20)}), {})); - - // Multi dim one way partial in dim 1. - ASSERT_TRUE( - EQ(subtractIndicesBounds({CB(0, 9), CB(0, 2)}, {CB(0, 3), CB(0, 2)}), - {{CB(4, 9), CB(0, 2)}})); - - // Multi dim one way partial in dim 2. - ASSERT_TRUE( - EQ(subtractIndicesBounds({CB(0, 9), CB(0, 20)}, {CB(0, 9), CB(0, 10)}), - {{CB(0, 9), CB(11, 20)}})); - - // Partial overlap in 2 dims. - ASSERT_TRUE( - EQ(subtractIndicesBounds({CB(0, 5), CB(0, 5)}, {CB(2, 8), CB(2, 8)}), - {{CB(0, 1), CB(0, 5)}, {CB(2, 5), CB(0, 1)}})); - - // Partial overlap in 3 dims. - ASSERT_TRUE( - EQ(subtractIndicesBounds( - {CB(0, 5), CB(0, 5), CB(0, 5)}, {CB(2, 8), CB(2, 8), CB(2, 8)}), - {{CB(0, 1), CB(0, 5), CB(0, 5)}, - {CB(2, 5), CB(0, 1), CB(0, 5)}, - {CB(2, 5), CB(2, 5), CB(0, 1)}})); -} - -// Tests the multi dimensional subtraction code for bounds that cannot be fully -// materialized. -TEST(MemDependency, BoundSubtractMultiDimSymbolic) { - VarHandle x("x", kInt); - VarHandle y("y", kInt); - - using namespace analysis; - - auto CB = [](ExprHandle s, ExprHandle e) { - return Bound(s.node(), e.node()); - }; - - auto EQ = [](std::vector x, std::vector y) { - if (x.size() != y.size()) { - return false; - } - for (auto i = 0U; i < x.size(); ++i) { - if (!indexBoundsEquals(x[i], y[i])) { - return false; - } - } - return true; - }; - - // Cannot determine overlaps. 
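[Editor's aside] The multi-dimensional expectations above (e.g. {0..5, 0..5} minus {2..8, 2..8} leaving {0..1, 0..5} and {2..5, 0..1}) look like a peel-one-dimension-at-a-time decomposition. A hedged standalone sketch of that decomposition under that assumption, with hypothetical names; the real subtractIndicesBounds may well be organized differently.

#include <algorithm>
#include <cstddef>
#include <utility>
#include <vector>

namespace slice_sketch {

using Range = std::pair<int, int>; // closed bounds
using Box = std::vector<Range>; // one Range per dimension

inline std::vector<Range> subtractRange(Range a, Range b) {
  if (b.second < a.first || a.second < b.first) {
    return {a};
  }
  std::vector<Range> out;
  if (a.first < b.first) {
    out.push_back({a.first, b.first - 1});
  }
  if (b.second < a.second) {
    out.push_back({b.second + 1, a.second});
  }
  return out;
}

// For each dimension d: dims before d are clamped to the overlap with B,
// dim d takes each leftover piece of A[d] minus B[d], and later dims keep
// A's full extent. Slabs with an empty clamped range are dropped.
inline std::vector<Box> subtractBoxes(const Box& A, const Box& B) {
  std::vector<Box> out;
  for (std::size_t d = 0; d < A.size(); ++d) {
    for (const Range& piece : subtractRange(A[d], B[d])) {
      Box slab;
      bool empty = false;
      for (std::size_t i = 0; i < d; ++i) {
        Range c{std::max(A[i].first, B[i].first),
                std::min(A[i].second, B[i].second)};
        if (c.first > c.second) {
          empty = true;
          break;
        }
        slab.push_back(c);
      }
      if (empty) {
        continue;
      }
      slab.push_back(piece);
      for (std::size_t i = d + 1; i < A.size(); ++i) {
        slab.push_back(A[i]);
      }
      out.push_back(slab);
    }
  }
  return out;
}

} // namespace slice_sketch

// subtractBoxes({{0, 5}, {0, 5}}, {{2, 8}, {2, 8}}) yields the same two
// slabs the 2-dim assertion above expects, and the 3-dim case produces the
// three slabs listed there, in the same order.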
- // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - ASSERT_TRUE(EQ(subtractIndicesBounds({CB(x, x)}, {CB(0, 0)}), {{CB(x, x)}})); - - // Various total Overlaps. - ASSERT_TRUE(EQ( - subtractIndicesBounds({CB(x, x), CB(x, x)}, {CB(x, x), CB(x, x)}), {})); - ASSERT_TRUE(EQ( - subtractIndicesBounds({CB(x, y), CB(x, y)}, {CB(x, y), CB(x, y)}), {})); - ASSERT_TRUE(EQ( - subtractIndicesBounds({CB(x, x), CB(y, y)}, {CB(x, x), CB(y, y)}), {})); - ASSERT_TRUE(EQ( - subtractIndicesBounds({CB(0, x), CB(0, y)}, {CB(0, x), CB(0, y)}), {})); - - // one-way overlap in first dim. - ASSERT_TRUE( - EQ(subtractIndicesBounds({CB(0, x), CB(0, y)}, {CB(0, x - 5), CB(0, y)}), - {{CB(x - 4, x), CB(0, y)}})); - // second dim. - ASSERT_TRUE( - EQ(subtractIndicesBounds({CB(0, x), CB(0, y)}, {CB(0, x), CB(5, y)}), - {{CB(0, x), CB(0, 4)}})); - - // Internal overlap in first dim. - ASSERT_TRUE( - EQ(subtractIndicesBounds({CB(0, x), CB(0, y)}, {CB(2, x - 5), CB(0, y)}), - {{CB(0, 1), CB(0, y)}, {CB(x - 4, x), CB(0, y)}})); - // second dim. - ASSERT_TRUE(EQ( - subtractIndicesBounds({CB(0, x), CB(0, y)}, {CB(0, x), CB(10, y - 10)}), - {{CB(0, x), CB(0, 9)}, {CB(0, x), CB(y - 9, y)}})); - - // Overlap in both dimensions. - ASSERT_TRUE( - EQ(subtractIndicesBounds( - {CB(0, x), CB(0, y)}, {CB(5, x - 5), CB(10, y - 10)}), - { - {CB(0, 4), CB(0, y)}, - {CB(x - 4, x), CB(0, y)}, - {CB(0, x), CB(0, 9)}, - {CB(0, x), CB(y - 9, y)}, - })); -} - -// Simple check that the analyzer does anything at all... -TEST(MemDependency, MemDependencyCheckerSimple) { - BufHandle a("A", {1}, kInt); - BufHandle b("B", {1}, kInt); - - analysis::MemDependencyChecker analyzer; - - /* - * A[0] = 3; - * B[0] = A[0] + 1; - */ - - StorePtr aStore = Store::make(a, {0}, 3); - StorePtr bStore = Store::make(b, {0}, Add::make(Load::make(a, {0}), 1)); - - StmtPtr stmt = Block::make({aStore, bStore}); - - stmt->accept(&analyzer); - - ASSERT_TRUE(analyzer.dependsDirectly(bStore, aStore)); - ASSERT_FALSE(analyzer.dependsIndirectly(aStore, bStore)); - // sanity check, but anything that depends directly must depend indirectly. - ASSERT_TRUE(analyzer.dependsIndirectly(bStore, aStore)); -} - -// Check that there is a difference between direct and indirect dependence. -TEST(MemDependency, MemDependencyCheckerMultiStmt) { - BufHandle a("A", {1}, kInt); - BufHandle b("B", {1}, kInt); - BufHandle c("C", {1}, kInt); - - analysis::MemDependencyChecker analyzer; - - /* - * A[0] = 3; - * B[0] = A[0]; - * C[0] = B[0] + 1; - */ - - StorePtr aStore = Store::make(a, {0}, 3); - StorePtr bStore = Store::make(b, {0}, Load::make(a, {0})); - StorePtr cStore = Store::make(c, {0}, Add::make(Load::make(b, {0}), 1)); - - StmtPtr stmt = Block::make({aStore, bStore, cStore}); - - stmt->accept(&analyzer); - - // C depends on A indirectly. - ASSERT_FALSE(analyzer.dependsDirectly(cStore, aStore)); - ASSERT_TRUE(analyzer.dependsIndirectly(cStore, aStore)); - - // C depends on B directly, which depends on A directly. - ASSERT_TRUE(analyzer.dependsDirectly(cStore, bStore)); - ASSERT_TRUE(analyzer.dependsDirectly(bStore, aStore)); - - // Dependency goes top to bottom only. - ASSERT_FALSE(analyzer.dependsIndirectly(bStore, cStore)); - ASSERT_FALSE(analyzer.dependsIndirectly(aStore, bStore)); - ASSERT_FALSE(analyzer.dependsIndirectly(aStore, cStore)); -} - -// Verify that we do filter writes that are totally overlapped by later writes. 
-TEST(MemDependency, MemDependencyCheckerOverlap) { - BufHandle a("A", {1}, kInt); - BufHandle b("B", {1}, kInt); - - analysis::MemDependencyChecker analyzer; - - /* - * A[0] = 3; - * A[0] = 6; - * B[0] = A[0] + 1; - */ - - StorePtr aStore = Store::make(a, {0}, 3); - StorePtr a2Store = Store::make(a, {0}, 6); - StorePtr bStore = Store::make(b, {0}, Add::make(Load::make(a, {0}), 1)); - - StmtPtr stmt = Block::make({aStore, a2Store, bStore}); - - stmt->accept(&analyzer); - - // B store depends on second A store but not first since it is completely - // overlapped. - ASSERT_TRUE(analyzer.dependsIndirectly(bStore, a2Store)); - ASSERT_FALSE(analyzer.dependsIndirectly(bStore, aStore)); - - // No dependency between either A store. - ASSERT_FALSE(analyzer.dependsIndirectly(aStore, a2Store)); - ASSERT_FALSE(analyzer.dependsIndirectly(a2Store, aStore)); -} - -// Verify that bounds match loop iterations, and that dependencies progress -// across loop scopes. -TEST(MemDependency, MemDependencyCheckerLoop) { - BufHandle a("A", {1}, kInt); - BufHandle b("B", {1}, kInt); - VarHandle x("x", kInt); - - using namespace analysis; - - MemDependencyChecker analyzer; - - /* - * for (int x = 0; x < 10; ++x) { - * A[x] = x; - * } - * B[0] = A[0] + 1; - */ - - StorePtr aStore = Store::make(a, {x}, x); - StmtPtr loop = For::make(x, 0, 10, aStore); - StorePtr bStore = Store::make(b, {0}, Add::make(Load::make(a, {4}), 1)); - - StmtPtr stmt = Block::make({loop, bStore}); - - stmt->accept(&analyzer); - - // Same A->B dependency. - ASSERT_TRUE(analyzer.dependsDirectly(bStore, aStore)); - - // B depends on the loop. - ASSERT_TRUE(analyzer.dependsDirectly(bStore, loop)); - // A is in the loop but does not depend on any loop iteration. - ASSERT_FALSE(analyzer.dependsIndirectly(aStore, loop)); - - auto aStoreAccess = analyzer.accessFor(aStore); - ASSERT_NE(aStoreAccess, nullptr); - - // It should have bounds covering the range of x: 0 <= x < 10. - ASSERT_TRUE(indexBoundsEquals( - aStoreAccess->bounds(), {Bound(alloc(0), alloc(9))})); -} - -// Reductions should promote dependencies as well. -TEST(MemDependency, MemDependencyCheckerLoopReduce) { - BufHandle a("A", {10}, kInt); - BufHandle b("B", {10}, kInt); - VarHandle x("x", kInt); - - using namespace analysis; - - MemDependencyChecker analyzer; - - /* - * A[0] = 0; - * for (int x = 0; x < 10; ++x) { - * A[0] = A[x] + 1; - * } - * B[0] = A[0]; - */ - - StorePtr aInit = Store::make(a, {0}, 0); - ExprHandle reduce = Sum()(a, 1, {x}, {x}); - StorePtr aReduce = Store::make(a, {0}, reduce); - StmtPtr loop = For::make(x, 0, 10, aReduce); - StorePtr bStore = Store::make(b, {0}, Load::make(a, {0})); - - StmtPtr stmt = Block::make({aInit, loop, bStore}); - - stmt->accept(&analyzer); - - // B -> A. - ASSERT_TRUE(analyzer.dependsDirectly(bStore, aReduce)); - - // B depends indirectly on the initializer of A, since the reduction depends - // on it. - ASSERT_FALSE(analyzer.dependsDirectly(bStore, aInit)); - ASSERT_TRUE(analyzer.dependsIndirectly(bStore, aInit)); - - ASSERT_TRUE(analyzer.dependsDirectly(aReduce, aInit)); - - // B depends on the loop. - ASSERT_TRUE(analyzer.dependsDirectly(bStore, loop)); - // A is in the loop and depends on other iterations. - ASSERT_TRUE(analyzer.dependsDirectly(aReduce, loop)); - - // The loop contents depend on the initializer too. - ASSERT_TRUE(analyzer.dependsDirectly(loop, aInit)); - - // Find loads within the reduction: - auto reduceLoads = NodeFinder::find(reduce.node()); - // Pull out the access for the load inside the loop. 
- for (auto load : reduceLoads) { - auto loopLoad = analyzer.accessFor(load); - // It should have 10 element long bounds. - ASSERT_TRUE(indexBoundsEquals( - loopLoad->bounds(), {Bound(alloc(0), alloc(9))})); - } -} - -// Lowering a reduction doesn't affect dependency analysis. -TEST(MemDependency, MemDependencyCheckerLoopReduceExpanded) { - BufHandle a("A", {10}, kInt); - BufHandle b("B", {10}, kInt); - VarHandle x("x", kInt); - - using namespace analysis; - - MemDependencyChecker analyzer; - - /* - * A[0] = 0; - * for (int x = 0; x < 10; ++x) { - * A[0] = A[x] + 1; - * } - * B[0] = A[0]; - */ - - StorePtr aInit = Store::make(a, {0}, 0); - ExprHandle aLoad = Load::make(a, {x}); - StorePtr aReduce = Store::make(a, {0}, Add::make(aLoad, 1)); - StmtPtr loop = For::make(x, 0, 10, aReduce); - StorePtr bStore = Store::make(b, {0}, Load::make(a, {0})); - - StmtPtr stmt = Block::make({aInit, loop, bStore}); - - stmt->accept(&analyzer); - - // B -> A. - ASSERT_TRUE(analyzer.dependsDirectly(bStore, aReduce)); - - // B depends indirectly on the initializer of A, since the reduction depends - // on it. - ASSERT_FALSE(analyzer.dependsDirectly(bStore, aInit)); - ASSERT_TRUE(analyzer.dependsIndirectly(bStore, aInit)); - - ASSERT_TRUE(analyzer.dependsDirectly(aReduce, aInit)); - - // B depends on the loop. - ASSERT_TRUE(analyzer.dependsDirectly(bStore, loop)); - // A is in the loop and depends on other iterations. - ASSERT_TRUE(analyzer.dependsDirectly(aReduce, loop)); - - // The loop contents depend on the initializer too. - ASSERT_TRUE(analyzer.dependsDirectly(loop, aInit)); - - // Pull out the access for the load inside the loop. - auto loopLoad = analyzer.accessFor(aLoad.node()); - // It should have 10 element long bounds. - ASSERT_TRUE(indexBoundsEquals( - loopLoad->bounds(), {Bound(alloc(0), alloc(9))})); -} - -// Can determine dependencies of outputs, through to inputs. -TEST(MemDependency, MemDependencyCheckerInputsOutputs) { - BufHandle a("A", {10}, kInt); - BufHandle b("B", {10}, kInt); - VarHandle x("x", kInt); - - // initialize analyzer with inputs and outputs. - analysis::MemDependencyChecker analyzer({a}, {b}); - - // Here's a Relu. - /* - * for (int x = 0; x < 10; ++x) { - * B[x] = Max(A[x], 0); - * } - */ - - ExprHandle aLoad = Load::make(a, {x}); - StorePtr bStore = Store::make(b, {x}, Max::make(aLoad, 0, true)); - StmtPtr loop = For::make(x, 0, 10, bStore); - - StmtPtr stmt = Block::make({loop}); - - stmt->accept(&analyzer); - - // Output depends indirectly on input. - ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node())); - // aLoad depends directly on the input A. - ASSERT_TRUE(analyzer.dependsDirectly(aLoad.node(), a.node())); - // bStore therefore depends directly on the input A. - ASSERT_TRUE(analyzer.dependsDirectly(bStore, a.node())); - // The output depends directly on the store. - ASSERT_TRUE(analyzer.dependsDirectly(b.node(), bStore)); - - // Check AccessInfo based overloads. - auto input = analyzer.input(a.node()); - auto output = analyzer.output(b.node()); - - // Output depends indirectly on input. - ASSERT_TRUE(analyzer.dependsIndirectly(output, input)); - // Not directly. - ASSERT_FALSE(analyzer.dependsDirectly(output, input)); - // Not in reverse order. - ASSERT_FALSE(analyzer.dependsIndirectly(input, output)); - - // output -> bStore -> bLoad -> input. 
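[Editor's aside] The chain in the comment above is just graph reachability: dependsDirectly defines the edges, and dependsIndirectly plausibly behaves like the transitive closure over them. A hedged standalone sketch (names hypothetical, not the analyzer's internals):

#include <map>
#include <set>
#include <vector>

namespace dep_sketch {

using Node = int; // stand-in for an access, statement, input, or output

// True when `to` is reachable from `from` over direct-dependence edges.
// Seeds the walk with from's direct dependencies, so reachable(x, x) is
// false unless x sits on a cycle.
inline bool reachable(
    const std::map<Node, std::vector<Node>>& edges,
    Node from,
    Node to) {
  std::vector<Node> stack;
  std::set<Node> seen;
  auto pushDeps = [&](Node n) {
    auto it = edges.find(n);
    if (it != edges.end()) {
      stack.insert(stack.end(), it->second.begin(), it->second.end());
    }
  };
  pushDeps(from);
  while (!stack.empty()) {
    Node n = stack.back();
    stack.pop_back();
    if (n == to) {
      return true;
    }
    if (seen.insert(n).second) {
      pushDeps(n);
    }
  }
  return false;
}

} // namespace dep_sketch

// With edges {output -> {bStore}, bStore -> {aLoad}, aLoad -> {input}},
// reachable(output, input) holds while reachable(input, output) does not,
// matching the dependsIndirectly assertions in this test.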
- auto storeAccess = analyzer.accessFor(bStore); - auto loadAccess = analyzer.accessFor(aLoad.node()); - - ASSERT_TRUE(analyzer.dependsDirectly(output, storeAccess)); - ASSERT_TRUE(analyzer.dependsDirectly(loadAccess, input)); -} - -// Can tell if an output does not depend on an input. -TEST(MemDependency, MemDependencyCheckerOutputDoesntDepend) { - BufHandle a("A", {10}, kInt); - BufHandle b("B", {10}, kInt); - VarHandle x("x", kInt); - - // initialize analyzer with inputs and outputs. - analysis::MemDependencyChecker analyzer({a}, {b}); - - // Here's a dumb Relu. - /* - * for (int x = 0; x < 10; ++x) { - * B[x] = Max(x, 0); - * } - */ - - StorePtr bStore = Store::make(b, {x}, Max::make(x, 0, true)); - StmtPtr loop = For::make(x, 0, 10, bStore); - - StmtPtr stmt = Block::make({loop}); - - stmt->accept(&analyzer); - - // Output does not depend indirectly on input. - ASSERT_FALSE(analyzer.dependsIndirectly(b.node(), a.node())); - - // The output still depends directly on the store. - ASSERT_TRUE(analyzer.dependsDirectly(b.node(), bStore)); - - // Check AccessInfo based overloads. - auto input = analyzer.input(a.node()); - auto output = analyzer.output(b.node()); - - // Output does not depend indirectly on input. - ASSERT_FALSE(analyzer.dependsIndirectly(output, input)); -} - -// Verify different loop extents produce accesses with different bounds, and -// that later accesses find dependencies that overlap their entire bound range. -TEST(MemDependency, MemDependencyCheckerLoopBounds) { - BufHandle a("A", {10}, kInt); - BufHandle b("B", {10}, kInt); - BufHandle c("C", {10}, kInt); - VarHandle x("x", kInt); - using namespace analysis; - - MemDependencyChecker analyzer({a}, {c}); - - // This enables using the execution order of the loops to determine if some - // loops are self dependent or not. - analyzer.allowLoopExecutionOrderAnalysis(); - - /* - * for (int x = 1; x < 10; ++x) { - * B[x] = A[x]; - * } - * for (int x = 1; x < 9; ++x) { - * B[x] = B[x] * 2; - * } - * for (int x = 3; x < 4; ++x) { - * C[x] = A[x]; - * } - * for (int x = 0; x < 10; ++x) { - * C[x] = B[x]; - * } - */ - - std::vector stmts( - {For::make(x, 1, 10, Store::make(b, {x}, Load::make(a, {x}))), - For::make( - x, 1, 9, Store::make(b, {x}, Mul::make(Load::make(b, {x}), 2))), - For::make(x, 3, 4, Store::make(c, {x}, Load::make(a, {x}))), - For::make(x, 0, 10, Store::make(c, {x}, Load::make(b, {x})))}); - - StmtPtr stmt = Block::make(stmts); - - stmt->accept(&analyzer); - - auto input = analyzer.input(a.node()); - auto output = analyzer.output(c.node()); - - // sanity check Output -> Input. - ASSERT_TRUE(analyzer.dependsIndirectly(output, input)); - - // Check the For loop dependencies: - - // Last write to C depends on both writes to B since they contain the last - // write to at least one element. - ASSERT_TRUE(analyzer.dependsIndirectly(stmts[3], stmts[1])); - ASSERT_TRUE(analyzer.dependsIndirectly(stmts[3], stmts[0])); - - // The last write to C does not depend on the other write to C. - ASSERT_FALSE(analyzer.dependsIndirectly(stmts[3], stmts[2])); - - auto CB = [](int s, int e) { - return Bound(alloc(s), alloc(e)); - }; - auto EQ = [](const IndexBounds& x, const IndexBounds& y) { - return indexBoundsEquals(x, y); - }; - - /* 0. Input: A[(0, 9)] - dependents: 1 5 - * 1. Load: A[(1, 9)] - depends on: 0 - dependents: 2 - * 2. Store: B[(1, 9)] - depends on: 1 - dependents: 3 7 - * 3. Load: B[(1, 8)] - depends on: 2 - dependents: 4 - * 4. Store: B[(1, 8)] - depends on: 3 - dependents: 7 - * 5. 
Load: A[(3, 3)] - depends on: 0 - dependents: 6 - * 6. Store: C[(3, 3)] - depends on: 5 - * 7. Load: B[(0, 9)] - depends on: 2 4 - dependents: 8 - * 8. Store: C[(0, 9)] - depends on: 7 - dependents: 9 - * 9. Output: C[(0, 9)] - depends on: 8 - */ - - // Now let's look at the bounds of each access. - // There are 10 accesses in this Stmt, so this is exhaustive; we won't do this - // much. - auto history = analyzer.getHistory(); - ASSERT_EQ(history.size(), 10); - VarPtr aVar = a.node()->base_handle(); - VarPtr bVar = b.node()->base_handle(); - VarPtr cVar = c.node()->base_handle(); - - // The first access is the input A. - ASSERT_EQ(history[0]->type(), AccessType::Input); - ASSERT_EQ(history[0]->var(), aVar); - // It has the bounds of the producing Input. - ASSERT_TRUE(EQ(history[0]->bounds(), {CB(0, 9)})); - // sanity check the input we retrieved earlier matches. - ASSERT_EQ(history[0], input); - - // The second access is the load of A in the first loop. - ASSERT_EQ(history[1]->type(), AccessType::Load); - ASSERT_EQ(history[1]->var(), aVar); - // It has the bounds of the loop, i.e. start == 1. - ASSERT_TRUE(EQ(history[1]->bounds(), {CB(1, 9)})); - // It reads from A, so it should have a dependency on the last write to this - // range - which is the input. - ASSERT_EQ(history[1]->dependencies().size(), 1); - ASSERT_TRUE(history[1]->hasDependency(history[0])); - - // The third access is the store into B in the first loop. - ASSERT_EQ(history[2]->type(), AccessType::Store); - ASSERT_EQ(history[2]->var(), bVar); - // It also has the bounds of the loop, i.e. start == 1. - ASSERT_TRUE(EQ(history[2]->bounds(), {CB(1, 9)})); - // The previous load is in its RHS, so it depends on it. - ASSERT_EQ(history[2]->dependencies().size(), 1); - ASSERT_TRUE(history[2]->hasDependency(history[1])); - - // The fourth access is the load from B in the second loop. - ASSERT_EQ(history[3]->type(), AccessType::Load); - ASSERT_EQ(history[3]->var(), bVar); - // It has the bounds of the second loop, i.e. >= 1 < 9. - ASSERT_TRUE(EQ(history[3]->bounds(), {CB(1, 8)})); - // It reads from B in a smaller range, so should depend on the previous - // store. - ASSERT_EQ(history[3]->dependencies().size(), 1); - ASSERT_TRUE(history[3]->hasDependency(history[2])); - - // The fifth: the store to B in the second loop. - ASSERT_EQ(history[4]->type(), AccessType::Store); - ASSERT_EQ(history[4]->var(), bVar); - // It also has the bounds of the second loop. - ASSERT_TRUE(EQ(history[4]->bounds(), {CB(1, 8)})); - // The previous load is in its RHS, so it depends on it as before. - ASSERT_EQ(history[4]->dependencies().size(), 1); - ASSERT_TRUE(history[4]->hasDependency(history[3])); - - // The sixth access is the load from the 3rd loop, and skips previous B - // accesses. - ASSERT_EQ(history[5]->type(), AccessType::Load); - ASSERT_EQ(history[5]->var(), aVar); - // It has the bounds of the third loop: >= 3 < 4. - ASSERT_TRUE(EQ(history[5]->bounds(), {CB(3, 3)})); - // It depends on the last thing to write to A, which is the A input. - ASSERT_EQ(history[5]->dependencies().size(), 1); - ASSERT_TRUE(history[5]->hasDependency(history[0])); - - // Seventh: the store into the output C. - ASSERT_EQ(history[6]->type(), AccessType::Store); - ASSERT_EQ(history[6]->var(), cVar); - // It also has the bounds of the third loop. - ASSERT_TRUE(EQ(history[6]->bounds(), {CB(3, 3)})); - // The previous load is in its RHS, so it depends on it as always. 
- ASSERT_EQ(history[6]->dependencies().size(), 1); - ASSERT_TRUE(history[6]->hasDependency(history[5])); - - // The eighth access is the load of B in the fourth loop. - ASSERT_EQ(history[7]->type(), AccessType::Load); - ASSERT_EQ(history[7]->var(), bVar); - // It has the bounds of the final loop, >= 0 < 10 - ASSERT_TRUE(EQ(history[7]->bounds(), {CB(0, 9)})); - // The bounds of this read are larger than the bounds of the previous write, - // so it depends on both previous Stores to B. - ASSERT_EQ(history[7]->dependencies().size(), 2); - ASSERT_TRUE(history[7]->hasDependency(history[2])); - ASSERT_TRUE(history[7]->hasDependency(history[4])); - - // Ninth: the final store into the output C. - ASSERT_EQ(history[8]->type(), AccessType::Store); - ASSERT_EQ(history[8]->var(), cVar); - // It also has the bounds of the final loop. - ASSERT_TRUE(EQ(history[8]->bounds(), {CB(0, 9)})); - // The previous load is in its RHS, so it depends on it as always. - ASSERT_EQ(history[8]->dependencies().size(), 1); - ASSERT_TRUE(history[8]->hasDependency(history[7])); - - // The last access represents the output Buf. - ASSERT_EQ(history[9]->type(), AccessType::Output); - ASSERT_EQ(history[9]->var(), cVar); - // It has the bounds of the output Buf. - ASSERT_TRUE(EQ(history[9]->bounds(), {CB(0, 9)})); - // sanity check the input we retrieved earlier matches. - ASSERT_EQ(history[9], output); - // It depends on the last write to C only. - ASSERT_EQ(history[9]->dependencies().size(), 1); - ASSERT_TRUE(history[9]->hasDependency(history[8])); -} - -// Verify that we can still infer bounds when the loop var is offset. -TEST(MemDependency, MemDependencyCheckerLoopBoundsIndexShift) { - BufHandle a("A", {10}, kInt); - BufHandle b("B", {10}, kInt); - VarHandle x("x", kInt); - - using namespace analysis; - - MemDependencyChecker analyzer({a}, {b}); - - // This enables using the execution order of the loops to determine if some - // loops are self dependent or not. - analyzer.allowLoopExecutionOrderAnalysis(); - - /* - * for (int x = 1; x < 10; x++) { - * A[x] = A[x - 1]; - * } - * for (int x = 0; x < 9; x++) { - * A[x] = A[x + 1]; - * } - * for (int x = 0; x < 9; x++) { - * A[9 - x] = A[8 - x]; - * } - * for (int x = 0; x < 10; x++) { - * A[x] = A[9 - x]; - * } - * for (int x = 0; x < 10; x++) { - * B[x] = A[x]; - * } - */ - - StmtPtr stmt = Block::make( - {For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1}))), - For::make(x, 0, 9, Store::make(a, {x}, Load::make(a, {x + 1}))), - For::make( - x, - 0, - 9, - Store::make( - a, {ExprHandle(9) - x}, Load::make(a, {ExprHandle(8) - x}))), - For::make( - x, 0, 10, Store::make(a, {x}, Load::make(a, {ExprHandle(9) - x}))), - For::make(x, 0, 10, Store::make(b, {x}, Load::make(a, {x})))}); - - stmt->accept(&analyzer); - - // Sanity check output depends on Input. - ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node())); - - auto CB = [](int s, int e) { - return Bound(alloc(s), alloc(e)); - }; - auto EQ = [](const IndexBounds& x, const IndexBounds& y) { - return indexBoundsEquals(x, y); - }; - - /* 0. Input: A[(0, 9)] - dependents: 1 - * 1. Load: A[(0, 8)] - depends on: 0 2 - dependents: 2 - * 2. Store: A[(1, 9)] - depends on: 1 - dependents: 1 3 - * 3. Load: A[(1, 9)] - depends on: 2 - dependents: 4 - * 4. Store: A[(0, 8)] - depends on: 3 - dependents: 5 7 - * 5. Load: A[(0, 8)] - depends on: 4 - dependents: 6 - * 6. Store: A[(1, 9)] - depends on: 5 - dependents: 7 - * 7. Load: A[(0, 9)] - depends on: 4 6 8 - dependents: 8 - * 8. 
Store: A[(0, 9)] - depends on: 7 - dependents: 7 9 - * 9. Load: A[(0, 9)] - depends on: 8 - dependents: 10 - * 10. Store: B[(0, 9)] - depends on: 9 - dependents: 11 - * 11. Output: B[(0, 9)] - depends on: 10 - */ - - // Now let's look at the bounds of each access. - auto history = analyzer.getHistory(); - ASSERT_EQ(history.size(), 12); - VarPtr aVar = a.node()->base_handle(); - VarPtr bVar = b.node()->base_handle(); - - // The first access is the input A. - ASSERT_EQ(history[0]->type(), AccessType::Input); - ASSERT_EQ(history[0]->var(), aVar); - // It has the bounds of the producing Input. - ASSERT_TRUE(EQ(history[0]->bounds(), {CB(0, 9)})); - - // The second access is the load A[x-1]. - ASSERT_EQ(history[1]->type(), AccessType::Load); - ASSERT_EQ(history[1]->var(), aVar); - // It has the bounds of the loop modified by the offset of each index, in - // this case -1. - ASSERT_TRUE(EQ(history[1]->bounds(), {CB(0, 8)})); - // It depends on the input, but also the store in the same loop, since - // different iterations of the loop depend on each other. - ASSERT_EQ(history[1]->dependencies().size(), 2); - ASSERT_TRUE(history[1]->hasDependency(history[0])); - ASSERT_TRUE(history[1]->hasDependency(history[2])); - - // The third access is the Store to A[x] in the first loop. - ASSERT_EQ(history[2]->type(), AccessType::Store); - ASSERT_EQ(history[2]->var(), aVar); - // It has no offset on x, so should have the same bounds as the loop. - ASSERT_TRUE(EQ(history[2]->bounds(), {CB(1, 9)})); - - // The fourth access is the load A[x+1] in the second loop. - ASSERT_EQ(history[3]->type(), AccessType::Load); - ASSERT_EQ(history[3]->var(), aVar); - // It has the bounds of the loop (0 <= x < 9) modified by the offset of each - // index, in this case 1. - ASSERT_TRUE(EQ(history[3]->bounds(), {CB(1, 9)})); - // This load totally overlaps the previous write to A, so it depends only on - // it and not the input. - ASSERT_EQ(history[3]->dependencies().size(), 1); - ASSERT_TRUE(history[3]->hasDependency(history[2])); - - // The fifth access is the store to A[x] in the second loop. - ASSERT_EQ(history[4]->type(), AccessType::Store); - ASSERT_EQ(history[4]->var(), aVar); - // It has no offset on x, so should have the same bounds as the loop. - ASSERT_TRUE(EQ(history[4]->bounds(), {CB(0, 8)})); - - // The sixth access is the load of A[8 - x] in the third loop. - ASSERT_EQ(history[5]->type(), AccessType::Load); - ASSERT_EQ(history[5]->var(), aVar); - // It has the bounds of the loop (0 <= x < 9) modified by the offset of each - // index, in this case 8 - x. - // This access has a negative stride, which will be normalized. - ASSERT_TRUE(EQ(history[5]->bounds(), {CB(0, 8)})); - // This load totally overlaps the most recent write to A, so it depends only - // on it and not the input or the first write to A. - ASSERT_EQ(history[5]->dependencies().size(), 1); - ASSERT_TRUE(history[5]->hasDependency(history[4])); - - // The seventh access is the store to A[9 - x] in the third loop. - ASSERT_EQ(history[6]->type(), AccessType::Store); - ASSERT_EQ(history[6]->var(), aVar); - // This store has a negative stride on its indices, but is normalized - // internally. - ASSERT_TRUE(EQ(history[6]->bounds(), {CB(1, 9)})); - - // The eighth access is the load A[9-x] in the fourth loop. - ASSERT_EQ(history[7]->type(), AccessType::Load); - ASSERT_EQ(history[7]->var(), aVar); - // It has the bounds of the loop (0 <= x < 10), modified by the offset 9 - x, - // which essentially traverses the loop backwards. 
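[Editor's note] Concretely: x = 0 touches A[9], x = 1 touches A[8], and so on down to x = 9 touching A[0], so the normalized bounds still cover the whole range [0, 9] even though the access walks the buffer backwards.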
- ASSERT_TRUE(EQ(history[7]->bounds(), {CB(0, 9)})); - // This Load has three write dependencies: - ASSERT_EQ(history[7]->dependencies().size(), 3); - // * The previous store (#6) for elements 1-9 - ASSERT_TRUE(history[7]->hasDependency(history[6])); - // * An earlier store (#4) covering element 0 - ASSERT_TRUE(history[7]->hasDependency(history[4])); - // * A future store inside this loop, since this loop modifies the buffer - // in a non-distinct way (due to the load and store having different access - // strides). - ASSERT_TRUE(history[7]->hasDependency(history[8])); - - // The ninth access is the store to A[x] in the fourth loop. - ASSERT_EQ(history[8]->type(), AccessType::Store); - ASSERT_EQ(history[8]->var(), aVar); - // This store has no offset on its index, so it has the same bounds as the - // loop. - ASSERT_TRUE(EQ(history[8]->bounds(), {CB(0, 9)})); - - // The tenth and eleventh accesses are the copy from A[x] to B[x]. - ASSERT_EQ(history[9]->type(), AccessType::Load); - ASSERT_EQ(history[9]->var(), aVar); - ASSERT_EQ(history[10]->type(), AccessType::Store); - ASSERT_EQ(history[10]->var(), bVar); - - // The last access represents the output Buf. - ASSERT_EQ(history[11]->type(), AccessType::Output); - ASSERT_EQ(history[11]->var(), bVar); - // It has the bounds of the output Buf. - ASSERT_TRUE(EQ(history[11]->bounds(), {CB(0, 9)})); - // It depends on the last write to B only. - ASSERT_EQ(history[11]->dependencies().size(), 1); - ASSERT_TRUE(history[11]->hasDependency(history[10])); - - // ok that's enough of that. -} - -// Check many different cases of loop self dependency - when a load within a -// loop is dependent on a Store later in the same loop but in a different -// iteration. This is affected by whether or not we can trust the execution -// order of the loop. -TEST(MemDependency, MemDependencyCheckerLoopSelfDependency) { - BufHandle a("A", {5}, kInt); - BufHandle b("B", {5}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - VarHandle z("z", kInt); - - using namespace analysis; - - // This check assumes that the Stmt has a single Store with a single Load on - // the RHS. - auto isSelfDependent = - [](const std::vector>& history) -> bool { - return history.front()->hasDependency(history.back()); - }; - - { - /* for (int y = 0; y < 10; y++) { - * A[y] = (A[y]) + 1; - * } */ - - // Not self dependent since all loop iterations use a different y. - - MemDependencyChecker analyzer; - StmtPtr stmt = For::make( - y, - 0, - 10, - Block::make({Store::make(a, {y}, Add::make(Load::make(a, {y}), 1))})); - - stmt->accept(&analyzer); - - ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int y = 0; y < 10; y++) { - * A[y + 1] = (A[y + 1]) + 1; - * } - */ - - // Not self dependent due to different y (with offset). - - MemDependencyChecker analyzer; - StmtPtr stmt = For::make( - y, - 0, - 10, - Block::make( - {Store::make(a, {y + 1}, Add::make(Load::make(a, {y + 1}), 1))})); - - stmt->accept(&analyzer); - - ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 10; x++) { - * A[0] = (A[0]) + x; - * } - */ - - // Is self dependent since all iterations use a common constant element of A. 
- - MemDependencyChecker analyzer; - StmtPtr stmt = For::make( - x, - 0, - 10, - Block::make({Store::make(a, {0}, Add::make(Load::make(a, {0}), x))})); - stmt->accept(&analyzer); - - ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 10; x++) { - * A[0] = (B[0]) + x; - * } - */ - - // Is not self dependent because there is no store to the buffer that is - // read. - - MemDependencyChecker analyzer; - StmtPtr stmt = For::make( - x, - 0, - 10, - Block::make({Store::make(a, {0}, Add::make(Load::make(b, {0}), x))})); - stmt->accept(&analyzer); - - ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 10; x++) { - * A[y] = (A[y]) + x; - * } - */ - - // Is self dependent since all iterations use a common symbolic element of A. - - MemDependencyChecker analyzer; - StmtPtr stmt = For::make( - x, - 0, - 10, - Block::make({Store::make(a, {y}, Add::make(Load::make(a, {y}), x))})); - stmt->accept(&analyzer); - - ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 10; x++) { - * A[x] = A[x + 1]; - * } - */ - - // In this case it depends on whether we are considering execution order. - - MemDependencyChecker analyzer; - - StmtPtr stmt = - For::make(x, 0, 10, Store::make(a, {x}, Load::make(a, {x + 1}))); - stmt->accept(&analyzer); - - // With analysis of order disabled, this is self dependent since the read - // from X+1 and the write to X+1 could be in reverse order. - ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 10; x++) { - * A[x] = A[x + 1]; - * } - */ - - MemDependencyChecker analyzer; - analyzer.allowLoopExecutionOrderAnalysis(); - - StmtPtr stmt = - For::make(x, 0, 10, Store::make(a, {x}, Load::make(a, {x + 1}))); - stmt->accept(&analyzer); - - // If order analysis is enabled, this is not dependent since the read for - // each element occurs before the write to that element. - ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 1; x < 10; x++) { - * A[x] = A[x - 1]; - * } - */ - - MemDependencyChecker analyzer; - - StmtPtr stmt = - For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1}))); - stmt->accept(&analyzer); - - ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 1; x < 10; x++) { - * A[x] = A[x - 1]; - * } - */ - - MemDependencyChecker analyzer; - analyzer.allowLoopExecutionOrderAnalysis(); - - StmtPtr stmt = - For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1}))); - stmt->accept(&analyzer); - - // In this case, even with order analysis the Load is dependent on the - // Store, since the write to X occurs before the read from X. - ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 3; x < 10; x++) { - * A[9 - x] = A[8 - x]; - * } - */ - - // Still works if the execution order is reversed, so long as the read - // comes before the write. - - MemDependencyChecker analyzer; - analyzer.allowLoopExecutionOrderAnalysis(); - - StmtPtr stmt = For::make( - x, - 3, - 10, - Store::make( - a, {ExprHandle(9) - x}, Load::make(a, {ExprHandle(8) - x}))); - stmt->accept(&analyzer); - - // However here we can determine that the A store is earlier in the order - // than the load. - ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 3; x < 10; x++) { - * A[8 - x] = A[9 - x]; - * } - */ - - // But not if it doesn't. 
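[Editor's aside] The execution-order cases above reduce to a simple brute-force check: with order analysis enabled, a self dependency needs some strictly earlier iteration to write the element a later iteration reads (within one iteration the RHS load happens before the store). A hedged standalone sketch of that rule, with hypothetical names:

namespace order_sketch {

// Models A[storeIdx(x)] = f(A[loadIdx(x)]) for x in [begin, end).
template <typename F, typename G>
bool selfDependent(int begin, int end, F storeIdx, G loadIdx) {
  for (int xw = begin; xw < end; ++xw) {
    for (int xr = begin; xr < end; ++xr) {
      // Only a write in a strictly earlier iteration can feed a later read.
      if (xw < xr && storeIdx(xw) == loadIdx(xr)) {
        return true;
      }
    }
  }
  return false;
}

} // namespace order_sketch

// Reproduces the ordered cases above:
//   A[x] = A[x + 1] over [0, 10)     -> false (each read happens first)
//   A[x] = A[x - 1] over [1, 10)     -> true
//   A[9 - x] = A[8 - x] over [3, 10) -> false (read precedes the write)
//   A[8 - x] = A[9 - x] over [3, 10) -> true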
- - MemDependencyChecker analyzer; - analyzer.allowLoopExecutionOrderAnalysis(); - - StmtPtr stmt = For::make( - x, - 3, - 10, - Store::make( - a, {ExprHandle(8) - x}, Load::make(a, {ExprHandle(9) - x}))); - stmt->accept(&analyzer); - - ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 3; x < 10; x++) { - * A[9 - x] = A[8 - x]; - * } - */ - - // And not if we're not relying on execution order. - - MemDependencyChecker analyzer; - - StmtPtr stmt = For::make( - x, - 3, - 10, - Store::make( - a, {ExprHandle(9) - x}, Load::make(a, {ExprHandle(8) - x}))); - stmt->accept(&analyzer); - - ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 3; x < 10; x++) { - * A[x - 2] = A[x - 1]; - * } - */ - - // Forward order but negative indices. - - MemDependencyChecker analyzer; - analyzer.allowLoopExecutionOrderAnalysis(); - - StmtPtr stmt = - For::make(x, 3, 10, Store::make(a, {x - 2}, Load::make(a, {x - 1}))); - stmt->accept(&analyzer); - - // However here we can determine that the A store is earlier in the order - // than the load. - ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 10; x++) { - * A[x * 2] = A[x * 2]; - * } - */ - - // With an access stride. - - MemDependencyChecker analyzer; - // Execution order doesn't matter since the read and the write are totally - // distinct. - - StmtPtr stmt = - For::make(x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2}))); - stmt->accept(&analyzer); - - ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 10; x++) { - * A[x * 2] = A[x * 2 + 1]; - * } - */ - - // Here we can use the common stride of the accesses to determine they are - // distinct. - // Note, this is the only place (loop self dependency) we use this stride - // to avoid unnecessary dependence. - - MemDependencyChecker analyzer; - // Execution order doesn't matter since the read and the write are totally - // distinct. - - StmtPtr stmt = For::make( - x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 + 1}))); - stmt->accept(&analyzer); - - ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 10; x++) { - * A[x * 2] = A[x * 2 - 1]; - * } - */ - - // Same if the read is behind the write, so long as they are distinct. - - MemDependencyChecker analyzer; - StmtPtr stmt = For::make( - x, 1, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 - 1}))); - stmt->accept(&analyzer); - - ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 10; x++) { - * A[x * 2] = A[x * 2 + 2]; - * } - */ - - // But not if the offset is in the stride. - - MemDependencyChecker analyzer; - StmtPtr stmt = For::make( - x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 + 2}))); - stmt->accept(&analyzer); - - ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 10; x++) { - * A[x * 2] = A[x * 2 - 2]; - * } - */ - - // Works with negative offsets too. - - MemDependencyChecker analyzer; - StmtPtr stmt = For::make( - x, 1, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 - 2}))); - stmt->accept(&analyzer); - - ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 10; x++) { - * A[x * 2] = A[x * 2 + 7]; - * } - */ - - // Detects accesses are distinct when offset is large but not a multiple - // of stride. 
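[Editor's aside] The stride-and-offset cases above and below follow a number-theoretic rule that can be sketched independently of loop extent: indices s1*x + o1 and s2*x' + o2 can collide only when (o1 - o2) is a multiple of gcd(s1, s2). A hedged standalone sketch (hypothetical name; positive strides assumed; loop extent, which the A[x] = A[x + 10] case further down relies on, is deliberately ignored here):

#include <numeric>

namespace stride_sketch {

inline bool mayOverlap(int s1, int o1, int s2, int o2) {
  int g = std::gcd(s1, s2);
  return (o1 - o2) % g == 0;
}

} // namespace stride_sketch

// e.g. mayOverlap(2, 0, 2, 1) is false (A[x * 2] vs A[x * 2 + 1]),
//      mayOverlap(2, 0, 2, 4) is true (the offset is a multiple of the stride),
//      mayOverlap(2, 0, 3, 1) is true (strides share no factor > 1).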
- MemDependencyChecker analyzer; - StmtPtr stmt = For::make( - x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 + 7}))); - stmt->accept(&analyzer); - - ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 10; x++) { - * A[x * 2] = A[x * 2 + 4]; - * } - */ - - // Works with offsets which are multiples of the stride. - MemDependencyChecker analyzer; - StmtPtr stmt = For::make( - x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 + 4}))); - stmt->accept(&analyzer); - - ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 10; x++) { - * A[x * 6] = A[x * 6 + 5]; - * } - */ - - // detects accesses are distinct with large strides when the offset is - // within. - - MemDependencyChecker analyzer; - StmtPtr stmt = For::make( - x, 0, 10, Store::make(a, {x * 6}, Load::make(a, {x * 6 + 5}))); - stmt->accept(&analyzer); - - ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 10; x++) { - * A[x * 2] = A[x * 6]; - * } - */ - - // detects accesses are overlapping when stride is different but a - // multiple. - - MemDependencyChecker analyzer; - StmtPtr stmt = - For::make(x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 6}))); - stmt->accept(&analyzer); - - ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 10; x++) { - * A[x * 4] = A[x * 2]; - * } - */ - - // still works when the read axis is the smaller stride. - - MemDependencyChecker analyzer; - StmtPtr stmt = - For::make(x, 0, 10, Store::make(a, {x * 4}, Load::make(a, {x * 2}))); - stmt->accept(&analyzer); - - ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 10; x++) { - * A[x * 2] = A[x * 6 + 1]; - * } - */ - - // detects accesses are distinct when stride is different but a multiple - // and there is an offset. - - MemDependencyChecker analyzer; - StmtPtr stmt = For::make( - x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 6 + 1}))); - stmt->accept(&analyzer); - - ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 10; x++) { - * A[x * 2] = A[x * 6 + 4]; - * } - */ - - // The smaller stride determines whether there is overlap. - - MemDependencyChecker analyzer; - StmtPtr stmt = For::make( - x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 6 + 4}))); - stmt->accept(&analyzer); - - ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 10; x++) { - * A[x * 2 + 3] = A[x * 6]; - * } - */ - - // The smaller stride determines whether there is overlap, not the larger. - - MemDependencyChecker analyzer; - StmtPtr stmt = For::make( - x, 0, 10, Store::make(a, {x * 2 + 3}, Load::make(a, {x * 6}))); - stmt->accept(&analyzer); - - ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 10; x++) { - * A[x * 2] = A[x * 3 + 1]; - * } - */ - - // If they have strides with no common multiple > 1, they overlap. - MemDependencyChecker analyzer; - StmtPtr stmt = For::make( - x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 3 + 1}))); - stmt->accept(&analyzer); - - ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 10; x++) { - * A[x] = A[x + 10]; - * } - */ - - // If the offset is greater than the size of the loop, they can't overlap. 
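// The strided cases above reduce to one arithmetic fact: accesses A[x*s1 + o1]
// and A[x*s2 + o2] over the same iteration space can only collide when the
// offset difference is a multiple of the strides' common divisor. A sketch of
// that rule as a plain integer check, using std::gcd from <numeric>; this
// mirrors the expectations the tests encode, not the checker's actual code:
static bool stridedAccessesMayOverlap(int s1, int o1, int s2, int o2) {
  int g = std::gcd(s1, s2);
  if (g <= 1) {
    return true; // no common stride to reason with: conservatively overlap
  }
  // x*2 vs x*2+1 -> distinct; x*2 vs x*6 -> overlap; x*2 vs x*6+1 -> distinct.
  return (o1 - o2) % g == 0;
}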
- - MemDependencyChecker analyzer; - StmtPtr stmt = - For::make(x, 0, 10, Store::make(a, {x}, Load::make(a, {x + 10}))); - stmt->accept(&analyzer); - - ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 10; x++) { - * A[x] = A[9 - x]; - * } - */ - - // If they have different execution orders they may overlap. - MemDependencyChecker analyzer; - StmtPtr stmt = For::make( - x, 0, 10, Store::make(a, {x}, Load::make(a, {ExprHandle(9) - x}))); - stmt->accept(&analyzer); - - ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 10; x++) { - * A[x * 2] = A[19 - x * 2]; - * } - */ - - // Or they may not, depending on their start offset and strides. - MemDependencyChecker analyzer; - StmtPtr stmt = For::make( - x, - 0, - 10, - Store::make(a, {x * 2}, Load::make(a, {ExprHandle(19) - x * 2}))); - stmt->accept(&analyzer); - - ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 10; x++) { - * A[x / 2] = A[x / 2]; - * } - */ - - // If the stride is not monotonic, they overlap. - - MemDependencyChecker analyzer; - StmtPtr stmt = - For::make(x, 0, 10, Store::make(a, {x / 2}, Load::make(a, {x / 2}))); - stmt->accept(&analyzer); - - ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 10; x++) { - * A[x / 2] = A[x / 2] + 1; - * } - */ - - // If the stride is not monotonic, they overlap - even with an offset. - MemDependencyChecker analyzer; - StmtPtr stmt = For::make( - x, 0, 10, Store::make(a, {x / 2}, Load::make(a, {x / 2 + 1}))); - stmt->accept(&analyzer); - - ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = 0; x < 10; x++) { - * A[x % 2] = A[x % 2]; - * } - */ - - // Mod too... - - analysis::MemDependencyChecker analyzer; - StmtPtr stmt = For::make( - x, - 0, - 10, - Store::make(a, {Mod::make(x, 2)}, Load::make(a, {Mod::make(x, 2)}))); - stmt->accept(&analyzer); - - ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); - } - - { - /* for (int x = y; x < z; x++) { - * A[x] = A[x + 1]; - * } - */ - - // Still works with symbolic loop extents. - - { - MemDependencyChecker analyzer; - StmtPtr stmt = - For::make(x, y, z, Store::make(a, {x}, Load::make(a, {x + 1}))); - stmt->accept(&analyzer); - - ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); - } - - { - MemDependencyChecker analyzer; - analyzer.allowLoopExecutionOrderAnalysis(); - StmtPtr stmt = - For::make(x, y, z, Store::make(a, {x}, Load::make(a, {x + 1}))); - stmt->accept(&analyzer); - - ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); - } - } -} - -// Verify that a strided access still works. -// TODO: actually this only works because of the size of the ranges, revisit -// this test after strided overlap is implemented. -TEST(MemDependency, MemDependencyCheckerLoopDistinctStrides) { - BufHandle a("A", {20}, kInt); - BufHandle b("B", {20}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - - using namespace analysis; - MemDependencyChecker analyzer({a.node()}, {b.node()}); - StmtPtr stmt = Block::make( - {For::make( - x, 0, 10, Store::make(b, {x * 2 + 1}, Load::make(a, {x * 2 + 1}))), - For::make(x, 0, 10, Store::make(b, {x * 2}, Load::make(a, {x * 2}))) - - }); - stmt->accept(&analyzer); - - // Sanity check output depends on input. - ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node())); - - // Output has 2 dependencies... the store in each loop. 
- auto outputAccess = analyzer.output(b.node()); - ASSERT_EQ(outputAccess->dependencies().size(), 2); -} - -/* TODO(nickg) - this test will fail due to the lack of stride math in Bound -TEST(MemDependency, MemDependencyCheckerLoopDistinctStrides) { - BufHandle a("A", {20}, kInt); - BufHandle b("B", {20}, kInt); - BufHandle c("C", {10}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - - { - analysis::MemDependencyChecker analyzer({a.node()}, {c.node()}); - StmtPtr stmt = Block::make( - {For::make( - x, - 0, - 10, - Store::make(b, {x * 2 + 1}, Load::make(a, {x * 2 + 1}))), - For::make( - x, 0, 10, Store::make(b, {x * 2}, Load::make(a, {x * 2}))), - For::make(x, 0, 10, Store::make(c, {x}, Load::make(b, {x}))) - - }); - stmt->accept(&analyzer); - - std::cout << *stmt << "\n"; - for (auto& wi : analyzer.getHistory()) { - wi->print(); - } - } -}*/ - -// analysis on Stmts using Cond. -TEST(MemDependency, MemDependencyCheckerLoopBoundsCond) { - BufHandle a("A", {10}, kInt); - BufHandle b("B", {10}, kInt); - BufHandle c("C", {10}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - - using namespace analysis; - - { - /* for (int x = 0; x < 10; x++) { - * C[x] = A[x]; - * } - * if (y<5 ? 1 : 0) { - * C[0] = (B[0]) + 1; - * } else { - * C[0] = (B[1]) + 1; - * } - */ - - // Future usages may depend on accesses in both branches of a condition. - - MemDependencyChecker analyzer({a, b}, {c}); - StmtPtr stmt = Block::make( - {For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}))), - Cond::make( - CompareSelect::make(y, 5, CompareSelectOperation::kLT), - Store::make(c, {0}, Add::make(Load::make(b, {0}), 1)), - Store::make(c, {0}, Add::make(Load::make(b, {1}), 1)))}); - - stmt->accept(&analyzer); - - // Output C should have 3 dependencies, each of the three stores. - auto outputAccess = analyzer.output(c.node()); - ASSERT_NE(outputAccess, nullptr); - ASSERT_EQ(outputAccess->dependencies().size(), 3); - - // C depends indirectly on A and B. - ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); - ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); - } - - { - /* for (int x = 0; x < 10; x++) { - * C[x] = A[x]; - * } - * if (y<5 ? 1 : 0) { - * for (int x = 0; x < 10; x++) { - * C[x] = B[x]; - * } - * } else { - * for (int x = 0; x < 10; x++) { - * C[x] = (B[x]) + 1; - * } - * } - */ - - // Future usages may depend on accesses in both branches of a condition. - - MemDependencyChecker analyzer({a, b}, {c}); - StmtPtr stmt = Block::make( - {For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}))), - Cond::make( - CompareSelect::make(y, 5, CompareSelectOperation::kLT), - For::make(x, 0, 10, Store::make(c, {x}, Load::make(b, {x}))), - For::make( - x, - 0, - 10, - Store::make(c, {x}, Add::make(Load::make(b, {x}), 1))))}); - - stmt->accept(&analyzer); - - // Output C should have 3 dependencies, each of the three stores. - auto outputAccess = analyzer.output(c.node()); - ASSERT_NE(outputAccess, nullptr); - ASSERT_EQ(outputAccess->dependencies().size(), 3); - - // TODO(nickg): actually since the true and false branch cover the total - // range of the first store this should have 2 dependencies, but we don't - // do that yet. - - // C depends indirectly on A and B. - ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); - ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); - } - - { - /* for (int x = 0; x < 10; x++) { - * C[x] = A[x]; - * } - * if (y<5 ? 
1 : 0) {
-     *   for (int x = 0; x < 10; x++) {
-     *     C[x] = (B[x]) + 1;
-     *   }
-     * }
-     */
-
-    // Only has true branch.
-
-    MemDependencyChecker analyzer({a, b}, {c});
-    StmtPtr stmt = Block::make(
-        {For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}))),
-         Cond::make(
-             CompareSelect::make(y, 5, CompareSelectOperation::kLT),
-             For::make(
-                 x,
-                 0,
-                 10,
-                 Store::make(c, {x}, Add::make(Load::make(b, {x}), 1))),
-             nullptr)});
-
-    stmt->accept(&analyzer);
-
-    // Output C should have 2 dependencies, each of the two stores.
-    auto outputAccess = analyzer.output(c.node());
-    ASSERT_NE(outputAccess, nullptr);
-    ASSERT_EQ(outputAccess->dependencies().size(), 2);
-
-    // C depends indirectly on A and B.
-    ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node()));
-    ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node()));
-  }
-
-  {
-    /* for (int x = 0; x < 10; x++) {
-     *   C[x] = A[x];
-     * }
-     * if (y<5 ? 1 : 0) {
-     * } else {
-     *   for (int x = 0; x < 10; x++) {
-     *     C[x] = (B[x]) + 1;
-     *   }
-     * }
-     */
-
-    // Only has false branch.
-
-    MemDependencyChecker analyzer({a, b}, {c});
-    StmtPtr stmt = Block::make(
-        {For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}))),
-         Cond::make(
-             CompareSelect::make(y, 5, CompareSelectOperation::kLT),
-             nullptr,
-             For::make(
-                 x,
-                 0,
-                 10,
-                 Store::make(c, {x}, Add::make(Load::make(b, {x}), 1))))});
-
-    stmt->accept(&analyzer);
-
-    // Output C should have 2 dependencies, each of the two stores.
-    auto outputAccess = analyzer.output(c.node());
-    ASSERT_NE(outputAccess, nullptr);
-    ASSERT_EQ(outputAccess->dependencies().size(), 2);
-
-    // C depends indirectly on A and B.
-    ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node()));
-    ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node()));
-  }
-
-  {
-    /* for (int x = 0; x < 10; x++) {
-     *   C[x] = A[x];
-     * }
-     * if (C[0]<5 ? 1 : 0) {
-     *   C[0] = 5;
-     * }
-     */
-
-    // Cond's Condition depends on a previous access.
-
-    MemDependencyChecker analyzer({a}, {c});
-    StorePtr initStore = Store::make(c, {x}, Load::make(a, {x}));
-    ExprHandle conditionalLoad = Load::make(c, {0});
-    StmtPtr stmt = Block::make(
-        {For::make(x, 0, 10, initStore),
-         Cond::make(
-             CompareSelect::make(
-                 conditionalLoad, 5, CompareSelectOperation::kLT),
-             Store::make(c, {0}, 5),
-             nullptr)});
-
-    stmt->accept(&analyzer);
-
-    ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node()));
-
-    ASSERT_TRUE(analyzer.dependsDirectly(conditionalLoad.node(), initStore));
-    ASSERT_FALSE(analyzer.dependsDirectly(conditionalLoad.node(), a.node()));
-    ASSERT_TRUE(analyzer.dependsIndirectly(conditionalLoad.node(), a.node()));
-  }
-}
-
-// Stmts using IfThenElse.
-TEST(MemDependency, MemDependencyCheckerIfThenElse) {
-  BufHandle a("A", {10}, kInt);
-  BufHandle b("B", {10}, kInt);
-  BufHandle c("C", {10}, kInt);
-  VarHandle x("x", kInt);
-  VarHandle y("y", kInt);
-
-  using namespace analysis;
-
-  {
-    /* for (int x = 0; x < 10; x++) {
-     *   C[x] = A[x];
-     * }
-     * C[0] = y < 5 ? (B[0]) + 1 : (B[1]) + 1;
-     */
-
-    // Future usages may depend on accesses in both branches of a condition.
-
-    MemDependencyChecker analyzer({a, b}, {c});
-    StorePtr ifStore = Store::make(
-        c,
-        {0},
-        IfThenElse::make(
-            CompareSelect::make(y, 5, CompareSelectOperation::kLT),
-            Add::make(Load::make(b, {0}), 1),
-            Add::make(Load::make(b, {1}), 1)));
-    StmtPtr stmt = Block::make(
-        {For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}))),
-         ifStore});
-
-    stmt->accept(&analyzer);
-
-    // Output C should have 2 dependencies, each of the two stores.
- auto outputAccess = analyzer.output(c.node()); - ASSERT_NE(outputAccess, nullptr); - ASSERT_EQ(outputAccess->dependencies().size(), 2); - - // Now we need to check the Store containing the IfThenElse. - auto ifStoreAccess = analyzer.accessFor(ifStore); - - // It should have 2 dependencies. - ASSERT_EQ(ifStoreAccess->dependencies().size(), 2); - - // C depends indirectly on A and B. - ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); - ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); - } - - { - /* for (int x = 0; x < 10; x++) { - * C[x] = A[x]; - * } - * C[0] = (y < 5 ? (B[0]) + 1 : 42; - */ - - // If the load appears in only one side of an IfThenElse the output may be - // dependent on it. - - MemDependencyChecker analyzer({a, b}, {c}); - StorePtr ifStore = Store::make( - c, - {0}, - IfThenElse::make( - CompareSelect::make(y, 5, CompareSelectOperation::kLT), - Add::make(Load::make(b, {0}), 1), - 42)); - StmtPtr stmt = Block::make( - {For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}))), - ifStore}); - - stmt->accept(&analyzer); - - // C depends indirectly on A and B. - ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); - ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); - } - - { - /* for (int x = 0; x < 10; x++) { - * C[x] = (x < 5 ? B[x] : A[x]; - * } - */ - - // In this case C is dependent on both A and B. - - // TODO: in cases like this it would be possible to split the range of B - // into two bounds, one dependent on A and one dependent on B. We'd need to - // examine conditions relative to previously encountered loop variables. I'm - // uncertain if this would be helpful. - - MemDependencyChecker analyzer({a, b}, {c}); - StorePtr ifStore = Store::make( - c, - {0}, - IfThenElse::make( - CompareSelect::make(y, 5, CompareSelectOperation::kLT), - Load::make(b, {x}), - Load::make(a, {x}))); - StmtPtr stmt = Block::make({For::make(x, 0, 10, ifStore)}); - - stmt->accept(&analyzer); - - // C depends indirectly on A and B. - ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); - ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); - } -} - -// Cutting a loop with single elem writes -TEST(MemDependency, MemDependencyCheckerCutLoop) { - BufHandle a("A", {10}, kInt); - BufHandle b("B", {10}, kInt); - VarHandle x("x", kInt); - - using namespace analysis; - - { - /* for (int x = 0; x < 10; x++) { - * B[x] = A[x]; - * } - * B[5] = 100; - */ - - // Cutting a loop with single element writes. - - MemDependencyChecker analyzer({a}, {b}); - StmtPtr stmt = Block::make( - {For::make(x, 0, 10, Store::make(b, {x}, Load::make(a, {x}))), - Store::make(b, {5}, 100)}); - - stmt->accept(&analyzer); - - // Output depends on input. - ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node())); - - // Output has 2 dependencies. - auto outputAccess = analyzer.output(b.node()); - ASSERT_NE(outputAccess, nullptr); - ASSERT_EQ(outputAccess->dependencies().size(), 2); - } - - { - /* for (int x = 0; x < 10; x++) { - * B[x] = A[x]; - * } - * for (int x = 4; x < 7; x++) { - * B[x] = B[x] + 3; - * } - * B[5] = 100; - * B[6] = 101; - * B[7] = 102; - */ - - // Cutting a loop with a smaller loop but then totally overlap that second - // loop with one element writes. 
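// The cut-loop cases come down to coverage bookkeeping: a write stops reaching
// the output once later writes cover every element it touched. A toy model of
// that rule over inclusive 1-d integer ranges (hypothetical helper, plain ints
// standing in for the analyzer's symbolic Bounds):
static bool writeReachesOutput(
    std::pair<int, int> write,
    const std::vector<std::pair<int, int>>& laterWrites) {
  for (int i = write.first; i <= write.second; ++i) {
    bool covered = false;
    for (const auto& w : laterWrites) {
      covered = covered || (i >= w.first && i <= w.second);
    }
    if (!covered) {
      return true; // some element survives to the output
    }
  }
  // Fully overwritten, e.g. a loop writing [4, 6] followed by single-element
  // stores to 4, 5 and 6, as in the case below.
  return false;
}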
- - MemDependencyChecker analyzer({a}, {b}); - ForPtr firstLoop = - For::make(x, 0, 10, Store::make(b, {x}, Load::make(a, {x}))); - StorePtr secondStore = - Store::make(b, {x}, Add::make(Load::make(b, {x}), 1)); - ForPtr secondLoop = For::make(x, 4, 7, secondStore); - - StmtPtr stmt = Block::make( - {firstLoop, - secondLoop, - Store::make(b, {4}, 100), - Store::make(b, {5}, 101), - Store::make(b, {6}, 102)}); - - stmt->accept(&analyzer); - - // Output depends on input. - ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node())); - - // Output has 4 dependencies. - auto outputAccess = analyzer.output(b.node()); - ASSERT_NE(outputAccess, nullptr); - ASSERT_EQ(outputAccess->dependencies().size(), 4); - - // Second loop depends on first loop. - ASSERT_TRUE(analyzer.dependsDirectly(secondLoop, firstLoop)); - - // Output does not depend on second loop or store. - ASSERT_FALSE(analyzer.dependsIndirectly(b.node(), secondLoop)); - ASSERT_FALSE(analyzer.dependsIndirectly(b.node(), secondStore)); - } -} - -// Dynamic shapes (load in indices). -TEST(MemDependency, MemDependencyCheckerDynamicShapes) { - BufHandle a("A", {100}, kInt); - BufHandle b("B", {100}, kInt); - BufHandle c("C", {100}, kInt); - VarHandle x("x", kInt); - - using namespace analysis; - - auto CB = [](ExprHandle s, ExprHandle e) { - return Bound(s.node(), e.node()); - }; - - auto EQ = [](const IndexBounds& x, const IndexBounds& y) { - return indexBoundsEquals(x, y); - }; - - { - /* for (int x = 0; x < B[0]; x++) { - * C[x] = A[x]; - * } - */ - MemDependencyChecker analyzer({a, b}, {c}); - StmtPtr stmt = Block::make({For::make( - x, 0, Load::make(b, {0}), Store::make(c, {x}, Load::make(a, {x})))}); - - stmt->accept(&analyzer); - - /* 0. Input: B[(0, 99)] - dependents: 2 - * 1. Input: A[(0, 99)] - dependents: 3 - * 2. Load: B[(0, 0)] - depends on: 0 - dependents: 3 4 - * 3. Load: A[(0, (B[0]) - 1)] - depends on: 1 2 - dependents: 4 - * 4. Store: C[(0, (B[0]) - 1)] - depends on: 2 3 - dependents: 5 - * 5. Output: C[(0, 99)] - depends on: 4 - */ - - // Output dependent on A input. - ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); - // Also dependent on B input to determine the size of the region written. - ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); - - auto history = analyzer.getHistory(); - ASSERT_EQ(history.size(), 6); - - // The accesses in the loop depend on the load in the stop condition. - ASSERT_TRUE(history[4]->hasDependency(history[2])); - ASSERT_TRUE(history[3]->hasDependency(history[2])); - - // Make a load from B to compare against. - ExprHandle loadFromB = Load::make(b, {0}); - - ASSERT_TRUE(EQ(history[3]->bounds(), {CB(0, loadFromB - 1)})); - ASSERT_TRUE(EQ(history[4]->bounds(), {CB(0, loadFromB - 1)})); - } - - { - /* for (int x = B[0]; x < B[1]; x++) { - * C[x] = A[x]; - * } - */ - MemDependencyChecker analyzer({a, b}, {c}); - StmtPtr stmt = Block::make({For::make( - x, - Load::make(b, {0}), - Load::make(b, {1}), - Store::make(c, {x}, Load::make(a, {x})))}); - - stmt->accept(&analyzer); - - /* 0. Input: B[(0, 99)] - dependents: 2 3 - * 1. Input: A[(0, 99)] - dependents: 4 - * 2. Load: B[(0, 0)] - depends on: 0 - dependents: 4 5 - * 3. Load: B[(1, 1)] - depends on: 0 - dependents: 4 5 - * 4. Load: A[(B[0], (B[1]) - 1)] - depends on: 1 2 3 - dependents: 5 - * 5. Store: C[(B[0], (B[1]) - 1)] - depends on: 2 3 4 - dependents: 6 - * 6. Output: C[(0, 99)] - depends on: 5 - */ - - // Sanity check output depends on input. 
- ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); - ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); - - auto history = analyzer.getHistory(); - ASSERT_EQ(history.size(), 7); - - // The accesses in the loop depend on the load in the start condition. - ASSERT_TRUE(history[5]->hasDependency(history[2])); - ASSERT_TRUE(history[4]->hasDependency(history[2])); - - // also the stop condition. - ASSERT_TRUE(history[5]->hasDependency(history[3])); - ASSERT_TRUE(history[4]->hasDependency(history[3])); - - // Make loads from B to compare against. - ExprHandle loadFromB0 = Load::make(b, {0}); - ExprHandle loadFromB1 = Load::make(b, {1}); - ASSERT_TRUE(EQ(history[4]->bounds(), {CB(loadFromB0, loadFromB1 - 1)})); - ASSERT_TRUE(EQ(history[5]->bounds(), {CB(loadFromB0, loadFromB1 - 1)})); - } - - { - /* for (int x = 0; x < 10; x++) { - * C[x] = A[B[x]]; - * } - */ - MemDependencyChecker analyzer({a, b}, {c}); - StmtPtr stmt = Block::make({For::make( - x, 0, 10, Store::make(c, {x}, Load::make(a, {Load::make(b, {x})})))}); - - stmt->accept(&analyzer); - - /* 0. Input: B[(0, 99)] - dependents: 2 - * 1. Input: A[(0, 99)] - dependents: 3 - * 2. Load: B[(0, 9)] - depends on: 0 - dependents: 3 4 - * 3. Load: A[(B[0], B[9])] - depends on: 1 2 - dependents: 4 - * 4. Store: C[(0, 9)] - depends on: 2 3 - dependents: 5 - * 5. Output: C[(0, 99)] - depends on: 4 - */ - - // Sanity check output depends on input. - ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); - ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); - - auto history = analyzer.getHistory(); - ASSERT_EQ(history.size(), 6); - - // The store depends on both loads, the load of A depends on the load of B. - ASSERT_TRUE(history[4]->hasDependency(history[2])); - ASSERT_TRUE(history[4]->hasDependency(history[3])); - - ASSERT_TRUE(history[3]->hasDependency(history[2])); - - // The loads in the indices depend on the relevant input buffer. - ASSERT_TRUE(history[3]->hasDependency(history[1])); - ASSERT_TRUE(history[2]->hasDependency(history[0])); - - // The load from B has the loop bounds. - ASSERT_TRUE(EQ(history[2]->bounds(), {CB(0, 9)})); - - // The load from A has bounds B[0] to B[9]. - ExprHandle loadFromB0 = Load::make(b, {0}); - ExprHandle loadFromB9 = Load::make(b, {9}); - ASSERT_TRUE(EQ(history[3]->bounds(), {CB(loadFromB0, loadFromB9)})); - } - - { - /* for (int x = 0; x < 10; x++) { - * C[B[x]] = A[x]; - * } - */ - MemDependencyChecker analyzer({a, b}, {c}); - StmtPtr stmt = Block::make({For::make( - x, 0, 10, Store::make(c, {Load::make(b, {x})}, Load::make(a, {x})))}); - - stmt->accept(&analyzer); - - /* 0. Input: B[(0, 99)] - dependents: 3 - * 1. Input: A[(0, 99)] - dependents: 2 - * 2. Load: A[(0, 9)] - depends on: 1 - dependents: 4 - * 3. Load: B[(0, 9)] - depends on: 0 - dependents: 4 - * 4. Store: C[(B[0], B[9])] - depends on: 2 3 - dependents: 5 - * 5. Output: C[(0, 99)] - depends on: 4 - */ - // Sanity check output depends on input. - ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); - ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); - - auto history = analyzer.getHistory(); - ASSERT_EQ(history.size(), 6); - - // The store depends on both loads, neither load is dependent. - ASSERT_TRUE(history[4]->hasDependency(history[2])); - ASSERT_TRUE(history[4]->hasDependency(history[3])); - - ASSERT_FALSE(history[3]->hasDependency(history[2])); - ASSERT_FALSE(history[2]->hasDependency(history[3])); - - // The loads each depend on their relevant input. 
(but accesses are in a - // different order than the last case). - ASSERT_TRUE(history[3]->hasDependency(history[0])); - ASSERT_TRUE(history[2]->hasDependency(history[1])); - - // The load from B has the loop bounds. - ASSERT_TRUE(EQ(history[3]->bounds(), {CB(0, 9)})); - - // And so does the load from A. - ASSERT_TRUE(EQ(history[2]->bounds(), {CB(0, 9)})); - } - - { - /* for (int x = 0; x < 10; x++) { - * C[B[A[x]]] = x; - * } - */ - MemDependencyChecker analyzer({a, b}, {c}); - StmtPtr stmt = Block::make({For::make( - x, 0, 10, Store::make(c, {Load::make(b, {Load::make(a, {x})})}, x))}); - - stmt->accept(&analyzer); - - /* 0. Input: B[(0, 99)] - dependents: 3 - * 1. Input: A[(0, 99)] - dependents: 2 - * 2. Load: A[(0, 9)] - depends on: 1 - dependents: 3 4 - * 3. Load: B[(A[0], A[9])] - depends on: 0 2 - dependents: 4 - * 4. Store: C[(B[A[0]], B[A[9]])] - depends on: 2 3 - dependents: 5 - * 5. Output: C[(0, 99)] - depends on: 4 - */ - - // Sanity check output depends on input. - ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); - ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); - - auto history = analyzer.getHistory(); - ASSERT_EQ(history.size(), 6); - - // The store depends on both loads. - ASSERT_TRUE(history[4]->hasDependency(history[2])); - ASSERT_TRUE(history[4]->hasDependency(history[3])); - - // The outer load depends on the inner. - ASSERT_TRUE(history[3]->hasDependency(history[2])); - - // The loads each depend on their relevant input. (but accesses are in a - // different order than the last case). - ASSERT_TRUE(history[3]->hasDependency(history[0])); - ASSERT_TRUE(history[2]->hasDependency(history[1])); - - // The load from A has the loop bounds. - ASSERT_TRUE(EQ(history[2]->bounds(), {CB(0, 9)})); - // The load from B as bounds A[0] to A[9]. - ExprHandle loadFromA0 = Load::make(a, {0}); - ExprHandle loadFromA9 = Load::make(a, {9}); - ASSERT_TRUE(EQ(history[3]->bounds(), {CB(loadFromA0, loadFromA9)})); - - // The store has bounds of B[A[0]] to B[A[9]]. - ExprHandle loadFromBA0 = Load::make(b, {loadFromA0}); - ExprHandle loadFromBA9 = Load::make(b, {loadFromA9}); - ASSERT_TRUE(EQ(history[4]->bounds(), {CB(loadFromBA0, loadFromBA9)})); - } -} - -// Verify multi dimensional bounds work. -TEST(MemDependency, MemDependencyCheckerMultiDim) { - int M = 10, N = 9, K = 12; - BufHandle a("A", {M, N, K}, kInt); - BufHandle b("B", {M, N, K}, kInt); - BufHandle c("C", {M, K}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - VarHandle z("z", kInt); - - using namespace analysis; - - auto CB = [](ExprHandle s, ExprHandle e) { - return Bound(s.node(), e.node()); - }; - - auto EQ = [](const IndexBounds& x, const IndexBounds& y) { - return indexBoundsEquals(x, y); - }; - - { - /* for (int x = 0; x < 10; x++) { - * for (int y = 0; y < 9; y++) { - * for (int z = 0; z < 12; z++) { - * B[x, y, z] = A[x, y, z]; - * } - * } - * } - */ - // Full range. - - MemDependencyChecker analyzer({a}, {b}); - StmtPtr stmt = Block::make({For::make( - x, - 0, - M, - For::make( - y, - 0, - N, - For::make( - z, - 0, - K, - Store::make(b, {x, y, z}, Load::make(a, {x, y, z})))))}); - - stmt->accept(&analyzer); - - // Sanity test: Output depends on input. - ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node())); - - // 4 accesses: input, load, store, output. - auto history = analyzer.getHistory(); - ASSERT_EQ(history.size(), 4); - - // Simple chain from input to output. 
- ASSERT_TRUE(history[3]->hasDependency(history[2])); - ASSERT_TRUE(history[2]->hasDependency(history[1])); - ASSERT_TRUE(history[1]->hasDependency(history[0])); - - ASSERT_TRUE( - EQ(history[1]->bounds(), {CB(0, M - 1), CB(0, N - 1), CB(0, K - 1)})); - ASSERT_TRUE( - EQ(history[2]->bounds(), {CB(0, M - 1), CB(0, N - 1), CB(0, K - 1)})); - } - - { - /* for (int x = 0; x < 5; x++) { - * for (int y = 0; y < 5; y++) { - * for (int z = 0; z < 5; z++) { - * B[x, y, z] = A[x, y, z]; - * } - * } - * } - */ - // Partial range. - - MemDependencyChecker analyzer({a}, {b}); - StmtPtr stmt = Block::make({For::make( - x, - 0, - 5, - For::make( - y, - 0, - 5, - For::make( - z, - 0, - 5, - Store::make(b, {x, y, z}, Load::make(a, {x, y, z})))))}); - - stmt->accept(&analyzer); - - // Sanity test: Output depends on input. - ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node())); - - // 4 accesses: input, load, store, output. - auto history = analyzer.getHistory(); - ASSERT_EQ(history.size(), 4); - - // Simple chain from input to output. - ASSERT_TRUE(history[3]->hasDependency(history[2])); - ASSERT_TRUE(history[2]->hasDependency(history[1])); - ASSERT_TRUE(history[1]->hasDependency(history[0])); - - ASSERT_TRUE(EQ(history[1]->bounds(), {CB(0, 4), CB(0, 4), CB(0, 4)})); - ASSERT_TRUE(EQ(history[2]->bounds(), {CB(0, 4), CB(0, 4), CB(0, 4)})); - } - - { - /* for (int x = 0; x < 10; x++) { - * for (int y = 0; y < 12; y++) { - * B[x, 0, y] = A[x, 0, y]; - * } - * } - */ - - // Partial loops. - - MemDependencyChecker analyzer({a}, {b}); - StmtPtr stmt = Block::make({For::make( - x, - 0, - N, - For::make( - y, 0, K, Store::make(b, {x, 0, y}, Load::make(a, {x, 0, y}))))}); - - stmt->accept(&analyzer); - - // Sanity test: Output depends on input. - ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node())); - - // 4 accesses: input, load, store, output. - auto history = analyzer.getHistory(); - ASSERT_EQ(history.size(), 4); - - // Simple chain from input to output. - ASSERT_TRUE(history[3]->hasDependency(history[2])); - ASSERT_TRUE(history[2]->hasDependency(history[1])); - ASSERT_TRUE(history[1]->hasDependency(history[0])); - - ASSERT_TRUE( - EQ(history[1]->bounds(), {CB(0, N - 1), CB(0, 0), CB(0, K - 1)})); - ASSERT_TRUE( - EQ(history[2]->bounds(), {CB(0, N - 1), CB(0, 0), CB(0, K - 1)})); - } - - { - /* for (int x = 0; x < 10; x++) { - * for (int y = 0; y < 100; y++) { - * for (int z = 0; z < 12; z++) { - * B[x, 0, z] = (A[x, 0, z]) + (C[x, z]); - * } - * } - * } - */ - - // Loops that don't correspond to an index, bufs with different - // dimensionality. - - MemDependencyChecker analyzer({a, c}, {b}); - StmtPtr stmt = Block::make({For::make( - x, - 0, - M, - For::make( - y, - 0, - 100, - For::make( - z, - 0, - K, - Store::make( - b, - {x, 0, z}, - Add::make( - Load::make(a, {x, 0, z}), Load::make(c, {x, z}))))))}); - - stmt->accept(&analyzer); - - // Sanity test: Output depends on both inputs. - ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node())); - ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), c.node())); - - // 6 accesses: 2 inputs, 2 loads, store, output. - auto history = analyzer.getHistory(); - ASSERT_EQ(history.size(), 6); - - // Simple chain from input to output over the A buf. - // history[0] is the C input, history[3] is the load from C. - ASSERT_TRUE(history[5]->hasDependency(history[4])); - ASSERT_TRUE(history[4]->hasDependency(history[2])); - ASSERT_TRUE(history[2]->hasDependency(history[1])); - // The store also depends on the load from the C input. 
-    ASSERT_TRUE(history[4]->hasDependency(history[3]));
-    ASSERT_TRUE(history[3]->hasDependency(history[0]));
-
-    // A Buf accesses.
-    ASSERT_TRUE(
-        EQ(history[4]->bounds(), {CB(0, M - 1), CB(0, 0), CB(0, K - 1)}));
-    ASSERT_TRUE(
-        EQ(history[2]->bounds(), {CB(0, M - 1), CB(0, 0), CB(0, K - 1)}));
-
-    // C buf access.
-    ASSERT_TRUE(EQ(history[3]->bounds(), {CB(0, M - 1), CB(0, K - 1)}));
-  }
-
-  {
-    /* for (int x = 0; x < 10; x++) {
-     *   for (int y = 0; y < 9; y++) {
-     *     for (int z = 0; z < 12; z++) {
-     *       B[x, 0, 0] = (B[x, y, z]) + (A[x, y, z]);
-     *     }
-     *   }
-     * }
-     */
-    // Multi-dim reductions.
-
-    MemDependencyChecker analyzer({a}, {b});
-    StmtPtr stmt = Block::make({For::make(
-        x,
-        0,
-        M,
-        For::make(
-            y,
-            0,
-            N,
-            For::make(
-                z,
-                0,
-                K,
-                Store::make(
-                    b,
-                    {x, 0, 0},
-                    Add::make(
-                        Load::make(b, {x, y, z}),
-                        Load::make(a, {x, y, z}))))))});
-
-    stmt->accept(&analyzer);
-
-    // Sanity test: Output depends on input.
-    ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node()));
-
-    // 5 accesses: input, 2 loads, store, output.
-    auto history = analyzer.getHistory();
-    ASSERT_EQ(history.size(), 5);
-
-    // Simple chain from input to output.
-    ASSERT_TRUE(history[4]->hasDependency(history[3]));
-    ASSERT_TRUE(history[3]->hasDependency(history[2]));
-    ASSERT_TRUE(history[3]->hasDependency(history[1]));
-    ASSERT_TRUE(history[2]->hasDependency(history[0]));
-
-    // The load from B depends on the store to B.
-    ASSERT_TRUE(history[1]->hasDependency(history[3]));
-
-    ASSERT_TRUE(
-        EQ(history[1]->bounds(), {CB(0, M - 1), CB(0, N - 1), CB(0, K - 1)}));
-    ASSERT_TRUE(
-        EQ(history[2]->bounds(), {CB(0, M - 1), CB(0, N - 1), CB(0, K - 1)}));
-    ASSERT_TRUE(EQ(history[3]->bounds(), {CB(0, M - 1), CB(0, 0), CB(0, 0)}));
-  }
-}
-
-// Various tests using the external Compute/Reduce API.
-TEST(MemDependency, MemDependencyCheckerComputeAPI) {
-  using namespace analysis;
-
-  /* for (int m = 0; m < 4; m++) {
-   *   for (int n = 0; n < 5; n++) {
-   *     for (int k = 0; k < 6; k++) {
-   *       broadcast_add[m, n, k] = (a[m, n]) + (b[n, k]);
-   *     }
-   *   }
-   * }
-   * for (int m_1 = 0; m_1 < 4; m_1++) {
-   *   for (int n_1 = 0; n_1 < 5; n_1++) {
-   *     for (int k_1 = 0; k_1 < 6; k_1++) {
-   *       d[m_1, n_1, k_1] = (broadcast_add(m_1, n_1, k_1)) + float(1);
-   *     }
-   *   }
-   * }
-   */
-
-  // Can determine if 2 loops created by Compute are dependent.
-  BufHandle a_buf("a", {4, 5}, kFloat);
-  BufHandle b_buf("b", {5, 6}, kFloat);
-  Tensor c = Compute(
-      "broadcast_add",
-      {4, 5, 6},
-      [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
-        return a_buf.load(m, n) + b_buf.load(n, k);
-      });
-  Tensor d = Compute(
-      "d",
-      {4, 5, 6},
-      [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
-        return c.load(m, n, k) + 1;
-      });
-
-  LoopNest l({d}, {c, d});
-
-  MemDependencyChecker analyzer({a_buf.node(), b_buf.node()}, {d.buf()});
-
-  l.root_stmt()->accept(&analyzer);
-
-  // Sanity test: Output depends on input.
-  ASSERT_TRUE(analyzer.dependsIndirectly(d.buf(), a_buf.node()));
-  ASSERT_TRUE(analyzer.dependsIndirectly(d.buf(), b_buf.node()));
-
-  // Second loop depends on first loop.
- auto c_loop = l.getLoopStmtsFor(c)[0]; - auto d_loop = l.getLoopStmtsFor(d)[0]; - ASSERT_TRUE(analyzer.dependsDirectly(d_loop, c_loop)); -} - -TEST(MemDependency, MemDependencyCheckerComputeInline) { - using namespace analysis; - - /* for (int m = 0; m < 4; m++) { - * for (int n = 0; n < 5; n++) { - * for (int k = 0; k < 6; k++) { - * d[m, n, k] = ((a[m, n]) + (b[n, k])) + float(1); - * } - * } - * } - */ - - // Check inlining affects the number of accesses returned. - - BufHandle a_buf("a", {4, 5}, kFloat); - BufHandle b_buf("b", {5, 6}, kFloat); - Tensor c = Compute( - "broadcast_add", - {4, 5, 6}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return a_buf.load(m, n) + b_buf.load(n, k); - }); - Tensor d = Compute( - "d", - {4, 5, 6}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return c.load(m, n, k) + 1; - }); - - LoopNest l({d}, {c, d}); - l.computeInline(c.buf()); - - MemDependencyChecker analyzer({a_buf.node(), b_buf.node()}, {d.buf()}); - l.root_stmt()->accept(&analyzer); - - // Sanity test: Output depends on input. - ASSERT_TRUE(analyzer.dependsIndirectly(d.buf(), a_buf.node())); - ASSERT_TRUE(analyzer.dependsIndirectly(d.buf(), b_buf.node())); - - // broadcast_add tensor should not appear in trace at all. - for (auto& wi : analyzer.getHistory()) { - ASSERT_NE(wi->var(), c.buf()->base_handle()); - } -} - -TEST(MemDependency, MemDependencyCheckerComputeSplit) { - using namespace analysis; - // Split an axis, so the number of loops != the number of dimensions. - - BufHandle a_buf("a", {4, 5}, kFloat); - BufHandle b_buf("b", {5, 6}, kFloat); - Tensor c = Compute( - "broadcast_add", - {4, 5, 6}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return a_buf.load(m, n) + b_buf.load(n, k); - }); - - LoopNest l({c}); - - MemDependencyChecker analyzer_before({a_buf.node(), b_buf.node()}, {c.buf()}); - l.root_stmt()->accept(&analyzer_before); - - l.splitWithTail(l.getLoopStmtsFor(c)[0], 2); - - MemDependencyChecker analyzer_after({a_buf.node(), b_buf.node()}, {c.buf()}); - StmtPtr stmt = IRSimplifier::simplify(l.root_stmt()); - stmt->accept(&analyzer_after); - - // Splitting should not change accesses at all. - auto history_before = analyzer_before.getHistory(); - auto history_after = analyzer_after.getHistory(); - - ASSERT_EQ(history_before.size(), history_after.size()); - - for (size_t i = 0; i < history_before.size(); ++i) { - ASSERT_EQ(history_before[i]->type(), history_after[i]->type()); - ASSERT_EQ(history_before[i]->var(), history_after[i]->var()); - ASSERT_EQ( - history_before[i]->bounds().size(), history_after[i]->bounds().size()); - ASSERT_TRUE(indexBoundsEquals( - history_before[i]->bounds(), history_after[i]->bounds())); - ASSERT_EQ( - history_before[i]->dependencies().size(), - history_after[i]->dependencies().size()); - ASSERT_EQ( - history_before[i]->dependents().size(), - history_after[i]->dependents().size()); - } -} - -TEST(MemDependency, MemDependencyCheckerComputeReorder) { - using namespace analysis; - // Reorder an axis, so the loop order doesn't match the indexing order. 
- - BufHandle a_buf("a", {4, 5}, kFloat); - BufHandle b_buf("b", {5, 6}, kFloat); - Tensor c = Compute( - "broadcast_add", - {4, 5, 6}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return a_buf.load(m, n) + b_buf.load(n, k); - }); - - LoopNest l({c}); - - MemDependencyChecker analyzer_before({a_buf.node(), b_buf.node()}, {c.buf()}); - l.root_stmt()->accept(&analyzer_before); - - auto loops = l.getLoopStmtsFor(c); - l.reorderAxis(loops[0], loops[1]); - - MemDependencyChecker analyzer_after({a_buf.node(), b_buf.node()}, {c.buf()}); - StmtPtr stmt = IRSimplifier::simplify(l.root_stmt()); - stmt->accept(&analyzer_after); - - // Reordering should not change accesses at all. - auto history_before = analyzer_before.getHistory(); - auto history_after = analyzer_after.getHistory(); - - ASSERT_EQ(history_before.size(), history_after.size()); - - for (size_t i = 0; i < history_before.size(); ++i) { - ASSERT_EQ(history_before[i]->type(), history_after[i]->type()); - ASSERT_EQ(history_before[i]->var(), history_after[i]->var()); - ASSERT_EQ( - history_before[i]->bounds().size(), history_after[i]->bounds().size()); - ASSERT_TRUE(indexBoundsEquals( - history_before[i]->bounds(), history_after[i]->bounds())); - ASSERT_EQ( - history_before[i]->dependencies().size(), - history_after[i]->dependencies().size()); - ASSERT_EQ( - history_before[i]->dependents().size(), - history_after[i]->dependents().size()); - } -} - -TEST(MemDependency, MemDependencyCheckerComputeReduce) { - using namespace analysis; - /* for (int l2 = 0; l2 < 2; l2++) { - * for (int n1 = 0; n1 < 3; n1++) { - * for (int m1 = 0; m1 < 6; m1++) { - * scale[l2, n1, m1] = (b[l2, n1, m1]) * (a[l2, n1, m1]); - * } - * } - * } - * for (int l1 = 0; l1 < 2; l1++) { - * sum[l1] = float(0); - * for (int n1_1 = 0; n1_1 < 3; n1_1++) { - * for (int m1_1 = 0; m1_1 < 6; m1_1++) { - * sum[l1] = ReduceOp(sum, (sum[l1]) + (scale(l1, n1_1, m1_1)), - * out_args={l1}, reduce_args={n1, m1}); - * } - * } - * } - */ - - // Can determine dependencies of a Reduction. - - BufHandle a("a", {2, 3, 6}, kFloat); - BufHandle b("b", {2, 3, 6}, kFloat); - - Tensor c = Compute( - "scale", - {2, 3, 6}, - [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) { - return b.load(l, n, m) * a.load(l, n, m); - }); - Tensor d = Reduce("sum", {2}, Sum(), c, {3, 6}); - LoopNest l({d}, {c, d}); - - MemDependencyChecker analyzer({a.node(), b.node()}, {d.buf()}); - - l.root_stmt()->accept(&analyzer); - - // Sanity test: Output depends on input. - ASSERT_TRUE(analyzer.dependsIndirectly(d.buf(), a.node())); - ASSERT_TRUE(analyzer.dependsIndirectly(d.buf(), b.node())); - - // Second loop depends on first loop. - auto c_loop = l.getLoopStmtsFor(c)[0]; - auto d_loop = l.getLoopStmtsFor(d)[0]; - ASSERT_TRUE(analyzer.dependsDirectly(d_loop, c_loop)); - - // Reduction depends on both inputs. 
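// (The NodeFinder call below appears to have lost its template argument in
// formatting; it presumably read NodeFinder<ReduceOp>::find(l.root_stmt()),
// i.e. collect every ReduceOp in the tree so the assertions can query the
// reduction node directly.)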
- auto reduces = NodeFinder::find(l.root_stmt()); - ASSERT_TRUE(analyzer.dependsIndirectly(reduces[0], a.node())); - ASSERT_TRUE(analyzer.dependsIndirectly(reduces[0], b.node())); -} - -TEST(MemDependency, MemDependencyCheckerComputeGEMM) { - int M = 1024; - int N = 1024; - int K = 2048; - using namespace analysis; - - BufHandle AP("A", {M, K}, kFloat); - BufHandle BP("B", {K, N}, kFloat); - Tensor CT = Reduce( - "gemm", - {M, N}, - Sum(), - [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) { - return AP.load(m, k) * BP.load(k, n); - }, - {K}); - LoopNest loop({CT}); - - { - auto const& loops = loop.getLoopStmtsFor(CT); - ForPtr m = loops[0]; - loop.splitWithMask(m, 4); - } - { - auto const& loops = loop.getLoopStmtsFor(CT); - ForPtr n = loops[2]; - loop.splitWithMask(n, 16); - } - // mo, mi, no, ni, k -> - // mo, no, mi, ni, k - { - auto const& loops = loop.getLoopStmtsFor(CT); - ForPtr mi = loops[1]; - ForPtr no = loops[2]; - loop.reorderAxis(mi, no); - } - // mo, no, mi, ni, k -> - // mo, no, mi, k, ni - { - auto const& loops = loop.getLoopStmtsFor(CT); - ForPtr ni = loops[3]; - ForPtr k = loops[4]; - loop.reorderAxis(ni, k); - } - // mo, no, mi, k, ni -> - // mo, no, k, mi, ni - { - auto const& loops = loop.getLoopStmtsFor(CT); - ForPtr mi = loops[2]; - ForPtr k = loops[3]; - loop.reorderAxis(mi, k); - } - { - auto const& loops = loop.getLoopStmtsFor(CT); - loop.cacheAccesses(CT.buf(), "C_regs", loops[2]); - } - - MemDependencyChecker analyzer_unlowered( - loop.getInputBufs(), loop.getOutputBufs()); - - MemDependencyChecker analyzer_lowered( - loop.getInputBufs(), loop.getOutputBufs()); - - // Test both unlowered and lowered form. - { - StmtPtr stmt = IRSimplifier::simplify(loop.root_stmt()); - stmt->accept(&analyzer_unlowered); - - // Outputs depend on inputs. - ASSERT_TRUE(analyzer_unlowered.dependsIndirectly(CT.buf(), AP.node())); - ASSERT_TRUE(analyzer_unlowered.dependsIndirectly(CT.buf(), BP.node())); - - // The last write to gemm should cover the total bound of the output. - std::shared_ptr outputAccess = - analyzer_unlowered.output(CT.buf()); - // A single dependency. - ASSERT_EQ(outputAccess->dependencies().size(), 1); - - // dependencies is a set with 1 element, so can just deref begin(). - std::shared_ptr gemmStore = - outputAccess->dependencies().begin()->second; - // Check its a store. - ASSERT_EQ(gemmStore->type(), AccessType::Store); - - ASSERT_TRUE(indexBoundsEquals(outputAccess->bounds(), gemmStore->bounds())); - - // Likewise the first read from each input cover the entire range of the - // input. - auto aInput = analyzer_unlowered.input(AP.node()); - auto bInput = analyzer_unlowered.input(BP.node()); - - // A single dependent each. - ASSERT_EQ(aInput->dependents().size(), 1); - ASSERT_EQ(bInput->dependents().size(), 1); - - // They're both loads. - std::shared_ptr aLoad = aInput->dependents().begin()->second; - std::shared_ptr bLoad = bInput->dependents().begin()->second; - ASSERT_EQ(aLoad->type(), AccessType::Load); - ASSERT_EQ(bLoad->type(), AccessType::Load); - - ASSERT_TRUE(indexBoundsEquals(aInput->bounds(), aLoad->bounds())); - ASSERT_TRUE(indexBoundsEquals(bInput->bounds(), bLoad->bounds())); - } - - loop.prepareForCodegen(); - SimpleIREvaluator cg(loop.root_stmt(), {AP, BP, CT}); - - // now check lowered dependency graph. - { - StmtPtr stmt = IRSimplifier::simplify(cg.stmt()); - stmt->accept(&analyzer_lowered); - - // Lowering will change the dimensionality of all bounds due to index - // flattening and will insert Allocates and Frees. 
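// Concretely: lowering collapses an access with per-dimension bounds [0, e_d]
// into a single flat bound [0, prod(e_d + 1) - 1]. The loop further below
// rebuilds that product out of IR nodes (its alloc<...> calls lost their
// template arguments in formatting; they construct IntImm, Add and Mul
// nodes). The same check in plain integers, as a sketch:
static int64_t flatExtent(const std::vector<int64_t>& dimEnds) {
  int64_t prod = 1;
  for (int64_t e : dimEnds) {
    prod *= e + 1; // a dimension bounded by [0, e] holds e + 1 elements
  }
  return prod; // the flattened bound is [0, prod - 1]
}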
- - auto history_before = analyzer_unlowered.getHistory(); - auto history_after = analyzer_lowered.getHistory(); - - ASSERT_EQ(history_before.size() + 2, history_after.size()); - - // Filter out the alloc/free; - auto isAllocFree = [](const auto& info) { - return info->type() == AccessType::Alloc || - info->type() == AccessType::Free; - }; - history_after.erase( - std::remove_if(history_after.begin(), history_after.end(), isAllocFree), - history_after.end()); - - ASSERT_EQ(history_before.size(), history_after.size()); - - for (size_t i = 0; i < history_before.size(); ++i) { - ASSERT_EQ(history_before[i]->type(), history_after[i]->type()); - ASSERT_EQ(history_before[i]->var(), history_after[i]->var()); - - if (history_before[i]->dependencies().size() != - history_after[i]->dependencies().size()) { - // Must depend on an Alloc. - ASSERT_TRUE(std::any_of( - history_after[i]->dependencies().begin(), - history_after[i]->dependencies().end(), - [](const auto& pair) { - return pair.second->type() == AccessType::Alloc; - })); - - ASSERT_EQ( - history_before[i]->dependencies().size() + 1, - history_after[i]->dependencies().size()); - } - - if (history_before[i]->dependents().size() != - history_after[i]->dependents().size()) { - // Must depend on an Free. - ASSERT_TRUE(std::any_of( - history_after[i]->dependents().begin(), - history_after[i]->dependents().end(), - [](const auto& pair) { - return pair.second->type() == AccessType::Free; - })); - - ASSERT_EQ( - history_before[i]->dependents().size() + 1, - history_after[i]->dependents().size()); - } - - // Inputs and outputs are not flattened, only accesses. - if (history_before[i]->type() == AccessType::Input || - history_before[i]->type() == AccessType::Output) { - ASSERT_EQ( - history_before[i]->bounds().size(), - history_after[i]->bounds().size()); - ASSERT_TRUE(indexBoundsEquals( - history_before[i]->bounds(), history_after[i]->bounds())); - } else { - ASSERT_EQ(history_after[i]->bounds().size(), 1); - ExprPtr flat_bounds = alloc(1); - - for (auto& b : history_before[i]->bounds()) { - flat_bounds = - alloc(flat_bounds, alloc(b.end, alloc(1))); - - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - ASSERT_TRUE(exprEquals(b.start, history_after[i]->bounds()[0].start)); - } - - flat_bounds = IRSimplifier::simplify(flat_bounds); - ExprPtr after_bounds = IRSimplifier::simplify( - alloc(history_after[i]->bounds()[0].end, alloc(1))); - ASSERT_TRUE(exprEquals(flat_bounds, after_bounds)); - } - } - } -} - -} // namespace jit -} // namespace torch diff --git a/test/cpp/tensorexpr/test_memplanning.cpp b/test/cpp/tensorexpr/test_memplanning.cpp deleted file mode 100644 index f5ee8747650f..000000000000 --- a/test/cpp/tensorexpr/test_memplanning.cpp +++ /dev/null @@ -1,708 +0,0 @@ -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include - -namespace torch { -namespace jit { - -using namespace torch::jit::tensorexpr; - -extern void checkIR(StmtPtr s, const std::string& pattern); - -TEST(BufLiveRange, SingleRangeLine) { - VarHandle i("i", kInt), j("j", kInt); - BufHandle a("a", {32}, kFloat); - BufHandle b("b", {32, 32}, kFloat); - - // Construct Stmt: - // { - // for (int i = 0; i < 32; i++) { - // a[i] = 0; - // for (int j = 0; j < 32; j++) { - // a[i] = (a[i]) + (b[i, j]); - // } - // } - // } - - StorePtr aInit = Store::make(a, {i}, 0); - ExprHandle reduce = a.load({i}) + b.load({i, j}); - StorePtr aReduce = Store::make(a, {i}, reduce); - StmtPtr loop = - For::make(i, 0, 32, Block::make({aInit, 
For::make(j, 0, 32, aReduce)})); - - StmtPtr stmt = Block::make({loop}); - - auto range = BufLiveRange::liveRange(stmt, a.node()); - ASSERT_TRUE(std::get<0>(range) == 0); - ASSERT_TRUE(std::get<1>(range) == 0); -} - -TEST(BufLiveRange, MulRangeLine) { - VarHandle i("i", kInt); - BufHandle a("a", {32}, kFloat); - BufHandle b("b", {32}, kFloat); - - // Construct Stmt: - // { - // for (int i = 0; i < 32; i++) { - // if (i<10 ? 1 : 0) { - // a[i] = i + i; - // b[i] = i * i; - // } - // } - // for (int i = 0; i < 32; i++) { - // if (i>10 ? 1 : 0) { - // a[i] = i * i; - // b[i] = i + i; - // } - // } - // } - - StorePtr aStore_1 = Store::make(a, {i}, i + i); - StorePtr bStore_1 = Store::make(b, {i}, i * i); - StmtPtr loop_1 = For::make( - i, 0, 32, Cond::make(i < 10, Block::make({aStore_1, bStore_1}), NULL)); - - StorePtr aStore_2 = Store::make(a, {i}, i * i); - StorePtr bStore_2 = Store::make(b, {i}, i + i); - StmtPtr loop_2 = For::make( - i, 0, 32, Cond::make(i > 10, Block::make({aStore_2, bStore_2}), NULL)); - - StmtPtr stmt = Block::make({loop_1, loop_2}); - - auto range_a = BufLiveRange::liveRange(stmt, a.node()); - ASSERT_TRUE(std::get<0>(range_a) == 0); - ASSERT_TRUE(std::get<1>(range_a) == 1); - - auto range_b = BufLiveRange::liveRange(stmt, b.node()); - ASSERT_TRUE(std::get<0>(range_b) == 0); - ASSERT_TRUE(std::get<1>(range_b) == 1); -} - -TEST(MemPlanning, MemReuseWithTypeCast) { - int M = 4; - int N = 4; - int K = 4; - - BufHandle AP("A", {M, K}, kFloat); - BufHandle BP("B", {K, N}, kFloat); - - Tensor CT = Reduce( - "gemm", - {M, N}, - Sum(), - [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) { - return AP.load(m, k) * BP.load(k, n); - }, - {K}); - Tensor DT = - Compute("relu", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { - return CompareSelect::make( - CT.load(m, n), 0.0f, 0.0f, CT.load(m, n), kLT); - }); - Tensor ET = - Compute("E", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { - return Cast::make(kQUInt8, DT.load(m, n) + DT.load(m, n)); - }); - Tensor FT = - Compute("F", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { - return ET.load(m, n); - }); - StmtPtr stmt = - tensorexpr::Block::make({CT.stmt(), DT.stmt(), ET.stmt(), FT.stmt()}); - - // Constructed stmt: - // Intermediate buffers and their liveness ranges: gemm [0, 1], relu [1, 2], - // E [2, 3]. The dimensions of 'gemm' and 'E' are the same but their types are - // different: 'E' type quint8 < 'gemm' type float. We'll reuse 'gemm' for 'E' - // with typecasting. - //{ - // for (int i = 0; i < 4; i++) { - // for (int i_1 = 0; i_1 < 4; i_1++) { - // gemm[i, i_1] = float(0); - // for (int i_2 = 0; i_2 < 4; i_2++) { - // gemm[i, i_1] = ReduceOp((gemm[i, i_1]) + (A[i, i_2]) * (B[i_2, - // i_1]), reduce_args={i_2}); - // } - // } - // } - // for (int i_3 = 0; i_3 < 4; i_3++) { - // for (int i_4 = 0; i_4 < 4; i_4++) { - // relu[i_3, i_4] = (gemm[i_3, i_4])<0.f ? 
0.f : (gemm[i_3, i_4]); - // } - // } - // for (int i_5 = 0; i_5 < 4; i_5++) { - // for (int i_6 = 0; i_6 < 4; i_6++) { - // E[i_5, i_6] = quint8((relu[i_5, i_6]) + (relu[i_5, i_6])); - // } - // } - // for (int i_7 = 0; i_7 < 4; i_7++) { - // for (int i_8 = 0; i_8 < 4; i_8++) { - // F[i_7, i_8] = E[i_7, i_8]; - // } - // } - //} - - LoopNest l(stmt, {FT.buf()}); - l.prepareForCodegen(); - SimpleIREvaluator cg(Stmt::clone(l.root_stmt()), {AP, BP, FT}); - - checkIR(cg.stmt(), R"IR( -# CHECK: Allocate(gemm); // dtype=float, dims=[4, 4] -# CHECK: Allocate(relu); // dtype=float, dims=[4, 4] -# CHECK: Alias(E,gemm); -# CHECK: Free(relu); -# CHECK: Free(gemm))IR"); - - PaddedBuffer a_v(M, K, "a"); - PaddedBuffer b_v(K, N, "b"); - PaddedBuffer o1(M, N, "e_before"); - PaddedBuffer o2(M, N, "e_after"); - - for (const auto m : c10::irange(M)) { - for (const auto k : c10::irange(K)) { - a_v(m, k) = at::randn({1}).item().to(); - } - } - - for (const auto k : c10::irange(K)) { - for (const auto n : c10::irange(N)) { - b_v(k, n) = at::randn({1}).item().to(); - } - } - - cg.call({a_v, b_v, o1}); - -#ifdef TORCH_ENABLE_LLVM - LLVMCodeGen cg_llvm(Stmt::clone(l.root_stmt()), {AP, BP, FT}); - - checkIR(cg_llvm.stmt(), R"IR( -# CHECK: Allocate(gemm); // dtype=float, dims=[4, 4] -# CHECK: Allocate(relu); // dtype=float, dims=[4, 4] -# CHECK: Alias(E,gemm); -# CHECK: Free(relu); -# CHECK: Free(gemm))IR"); - - cg_llvm.call({a_v, b_v, o2}); - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - ExpectAllNear(o1, o2, 1e-5); -#endif -} - -TEST(MemPlanning, NoMemReuseForLargerType) { - int M = 4; - int N = 4; - int K = 4; - - BufHandle AP("A", {M, K}, kShort); - BufHandle BP("B", {K, N}, kShort); - - Tensor CT = Reduce( - "gemm", - {M, N}, - Sum(), - [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) { - return AP.load(m, k) * BP.load(k, n); - }, - {K}); - auto zero = Cast::make(CT.buf()->dtype(), 0); - Tensor DT = - Compute("relu", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { - return CompareSelect::make( - CT.load(m, n), zero, zero, CT.load(m, n), kLT); - }); - Tensor ET = - Compute("E", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { - return Cast::make(kFloat, DT.load(m, n) + DT.load(m, n)); - }); - Tensor FT = - Compute("F", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { - return ET.load(m, n); - }); - StmtPtr stmt = - tensorexpr::Block::make({CT.stmt(), DT.stmt(), ET.stmt(), FT.stmt()}); - - // Constructed stmt: - // Intermediate buffers and their liveness ranges: gemm [0, 1], relu [1, 2], - // E [2, 3]. The dimensions of 'gemm' and 'E' are the same but their types are - // different: 'E' type float > 'gemm' type int16. We won't reuse 'gemm' for - // 'E'. 
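// The reuse rule these memplanning tests pin down: an intermediate buffer may
// alias an earlier allocation only if their live ranges are disjoint and it
// needs no more storage than the earlier buffer provides (same dims, element
// type no larger). A sketch of that predicate (hypothetical helper; the live
// ranges are (first, last) statement indices as returned by
// BufLiveRange::liveRange above):
static bool canReuseAllocation(
    const std::tuple<int32_t, int32_t>& prevLiveRange,
    const std::tuple<int32_t, int32_t>& nextLiveRange,
    size_t prevBytes,
    size_t nextBytes) {
  bool disjoint = std::get<1>(prevLiveRange) < std::get<0>(nextLiveRange);
  // quint8 'E' fits in the float 'gemm'; a float buffer does not fit in int16.
  return disjoint && nextBytes <= prevBytes;
}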
- //{ - // for (int i = 0; i < 4; i++) { - // for (int i_1 = 0; i_1 < 4; i_1++) { - // gemm[i, i_1] = int16_t(0); - // for (int i_2 = 0; i_2 < 4; i_2++) { - // gemm[i, i_1] = ReduceOp((gemm[i, i_1]) + (A[i, i_2]) * (B[i_2, - // i_1]), reduce_args={i_2}); - // } - // } - // } - // for (int i_3 = 0; i_3 < 4; i_3++) { - // for (int i_4 = 0; i_4 < 4; i_4++) { - // relu[i_3, i_4] = (gemm[i_3, i_4]) a_v(M, K, "a"); - PaddedBuffer b_v(K, N, "b"); - PaddedBuffer o1(M, N, "e_before"); - PaddedBuffer o2(M, N, "e_after"); - - for (const auto m : c10::irange(M)) { - for (const auto k : c10::irange(K)) { - a_v(m, k) = at::randn({1}).item().to(); - } - } - - for (const auto k : c10::irange(K)) { - for (const auto n : c10::irange(N)) { - b_v(k, n) = at::randn({1}).item().to(); - } - } - - cg.call({a_v, b_v, o1}); - -#ifdef TORCH_ENABLE_LLVM - LLVMCodeGen cg_llvm(Stmt::clone(l.root_stmt()), {AP, BP, FT}); - - checkIR(cg_llvm.stmt(), R"IR( -# CHECK: Allocate(gemm); // dtype=int16_t, dims=[4, 4] -# CHECK: Allocate(relu); // dtype=int16_t, dims=[4, 4] -# CHECK: Allocate(E); // dtype=float, dims=[4, 4] -# CHECK: Free(E); -# CHECK: Free(relu); -# CHECK: Free(gemm))IR"); - - cg_llvm.call({a_v, b_v, o2}); - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - ExpectAllNear(o1, o2, 1e-5); -#endif -} - -TEST(MemPlanning, SameBufSizeMemReuse) { - int M = 1024; - int N = 1024; - int K = 2048; - - BufHandle AP("A", {M, K}, kFloat); - BufHandle BP("B", {K, N}, kFloat); - - Tensor CT = Reduce( - "gemm", - {M, N}, - Sum(), - [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) { - return AP.load(m, k) * BP.load(k, n); - }, - {K}); - Tensor DT = - Compute("relu", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { - auto zero = Cast::make(CT.buf()->dtype(), 0); - return CompareSelect::make( - CT.load(m, n), zero, zero, CT.load(m, n), kLT); - }); - Tensor ET = - Compute("add", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { - return DT.load(m, n) + DT.load(m, n); - }); - Tensor FT = - Compute("mul", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { - return ET.load(m, n) * ET.load(m, n); - }); - auto stmt = Block::make({CT.stmt(), DT.stmt(), ET.stmt(), FT.stmt()}); - - // Constructed stmt: - // Intermediate buffers and their liveness ranges: gemm [0, 1], relu [1, 2], - // add [2, 3] Buffer 'gemm' and 'add' are the same size; we'll reuse 'gemm' - // for 'add'. 
- //{ - // for (int M = 0; M < 1024; M++) { - // for (int N = 0; N < 1024; N++) { - // gemm[M, N] = float(0); - // for (int K = 0; K < 2048; K++) { - // gemm[M, N] = ReduceOp((gemm[M, N]) + (A[M, K]) * (B[K, N]), - // reduce_args={K}); - // } - // } - // } - // for (int M_1 = 0; M_1 < 1024; M_1++) { - // for (int N_1 = 0; N_1 < 1024; N_1++) { - // relu[M_1, N_1] = (gemm[M_1, N_1])dtype(), 0); - return CompareSelect::make( - CT.load(m, n), zero, zero, CT.load(m, n), kLT); - }); - Tensor ET = - Compute("add", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { - return DT.load(m, n) + DT.load(m, n); - }); - Tensor FT = - Compute("mul", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { - return ET.load(m, n) * ET.load(m, n); - }); - Tensor GT = - Compute("sub", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { - return FT.load(m, n) - ET.load(m, n); - }); - - auto stmt = - Block::make({CT.stmt(), DT.stmt(), ET.stmt(), FT.stmt(), GT.stmt()}); - - // Constructed stmt: - // Intermediate buffers and their liveness ranges: gemm [0, 1], relu [1, 2], - // add [2, 3], mul [3, 4] Buffer 'gemm', 'relu, ''add' and 'mul' are the same - // size; we'll reuse 'gemm' for 'add', and reuse 'relu' for 'mul' - //{ - // for (int M = 0; M < 1024; M++) { - // for (int N = 0; N < 1024; N++) { - // gemm[M, N] = float(0); - // for (int K = 0; K < 2048; K++) { - // gemm[M, N] = ReduceOp((gemm[M, N]) + (A[M, K]) * (B[K, N]), - // reduce_args={K}); - // } - // } - // } - // for (int M_1 = 0; M_1 < 1024; M_1++) { - // for (int N_1 = 0; N_1 < 1024; N_1++) { - // relu[M_1, N_1] = (gemm[M_1, N_1])dtype(), 0); - return CompareSelect::make( - CT.load(m, n), zero, zero, CT.load(m, n), kLT); - }); - Tensor ET = - Compute("add", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { - return DT.load(m, n) + DT.load(m, n); - }); - Tensor FT = - Compute("mul", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { - return ET.load(m, n) * ET.load(m, n); - }); - Tensor GT = - Compute("sub", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { - return FT.load(m, n) - 1; - }); - Tensor HT = - Compute("div", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { - return GT.load(m, n) / 2; - }); - - auto stmt = Block::make( - {CT.stmt(), DT.stmt(), ET.stmt(), FT.stmt(), GT.stmt(), HT.stmt()}); - - // Constructed stmt: - // Intermediate buffers and their liveness ranges: gemm [0, 1], relu [1, 2], - // add [2, 3], mul [3, 4], sub [4, 5] Buffer 'gemm', 'relu, ''add', 'mul' and - // 'sub' are the same size; we'll reuse 'gemm' for 'add', reuse 'relu' for - // 'mul', and reuse 'gemm' for 'sub'. 
- //{ - // for (int M = 0; M < 1024; M++) { - // for (int N = 0; N < 1024; N++) { - // gemm[M, N] = float(0); - // for (int K = 0; K < 2048; K++) { - // gemm[M, N] = ReduceOp((gemm[M, N]) + (A[M, K]) * (B[K, N]), - // reduce_args={K}); - // } - // } - // } - // for (int M_1 = 0; M_1 < 1024; M_1++) { - // for (int N_1 = 0; N_1 < 1024; N_1++) { - // relu[M_1, N_1] = (gemm[M_1, N_1])dtype(), 0); - return CompareSelect::make( - CT.load(m, n), zero, zero, CT.load(m, n), kLT); - }); - Tensor ET = Compute( - "add", {M * 2, N * 2}, [&](const ExprHandle& em, const ExprHandle& en) { - return DT.load(em / 2, en / 2) + DT.load(em / 2, en / 2); - }); - Tensor FT = Compute( - "mul", {M * 2, N * 2}, [&](const ExprHandle& fm, const ExprHandle& fn) { - return ET.load(fm, fn) * ET.load(fm, fn); - }); - auto stmt = Block::make({CT.stmt(), DT.stmt(), ET.stmt(), FT.stmt()}); - - // Constructed stmt: - // Intermediate buffers and their liveness ranges: gemm [0, 1], relu [1, 2], - // add [2, 3] We do not reuse buffer 'gemm' for 'add' because the size of - // buffer 'gemm' is smaller. - //{ - // for (int M = 0; M < 1024; M++) { - // for (int N = 0; N < 1024; N++) { - // gemm[M, N] = float(0); - // for (int K = 0; K < 2048; K++) { - // gemm[M, N] = ReduceOp((gemm[M, N]) + (A[M, K]) * (B[K, N]), - // reduce_args={K}); - // } - // } - // } - // for (int M_1 = 0; M_1 < 1024; M_1++) { - // for (int N_1 = 0; N_1 < 1024; N_1++) { - // relu[M_1, N_1] = (gemm[M_1, N_1]) -#include -#include -#include -#include -#include - -using namespace torch::jit::tensorexpr; - -using Tensors = std::vector; -using Args = std::vector; -std::unique_ptr compile( - const Args& inputs, - const Tensors& outputs) { - LoopNest nest({outputs}); - nest.prepareForCodegen(); - nest.simplify(); - auto join = inputs; - join.insert(join.end(), outputs.begin(), outputs.end()); - return std::make_unique(nest.root_stmt(), join); -} - -TEST(Ops, Sum) { - constexpr int M = 8; - constexpr int N = 16; - std::vector testDims = {{0}, {1}, {0, 1}}; - std::vector> outputShapes = {{N}, {M}, {}}; - for (unsigned idx = 0; idx < testDims.size(); idx++) { - const auto& dims = testDims[idx]; - const auto& outShape = outputShapes[idx]; - - BufHandle a("a", {M, N}, kFloat); - std::vector outStrides = - c10::fmap(make_contiguous_strides(outShape)); - Tensor b = computeSum( - {a, dims, false}, outShape, outStrides, c10::kFloat, at::kCPU); - auto cg = compile({a}, {b}); - - auto at = at::arange(M * N, at::kFloat).view({M, N}); - auto ref = at::sum(at, dims); - auto bt = at::empty_like(ref); - - cg->call({at.data_ptr(), bt.data_ptr()}); - - ASSERT_TRUE(at::allclose(bt, ref)); - } -} - -TEST(Ops, ChannelsLastSum) { - constexpr int A = 2; - constexpr int B = 3; - constexpr int C = 4; - constexpr int D = 5; - constexpr int E = 6; - std::vector testDims = {{0}, {1}, {0, 1}}; - - std::vector> outputShapes = { - {B, C, D, E}, {A, C, D, E}, {C, D, E}}; - for (unsigned idx = 0; idx < testDims.size(); idx++) { - const auto& dims = testDims[idx]; - const auto& outShape = outputShapes[idx]; - - BufHandle a("a", {A, B, C, D, E}, kFloat); - std::vector outStrides = - c10::fmap(make_channels_last_strides(outShape)); - Tensor b = computeSum( - {a, dims, false}, outShape, outStrides, c10::kFloat, at::kCPU); - auto cg = compile({a}, {b}); - - auto at = at::arange(A * B * C * D * E, at::kFloat).view({A, B, C, D, E}); - auto ref = at::sum(at, dims); - auto bt = at::empty_like(ref); - - cg->call({at.data_ptr(), bt.data_ptr()}); - - ASSERT_TRUE(at::allclose(bt, ref)); - } -} diff --git 
a/test/cpp/tensorexpr/test_quantization.cpp b/test/cpp/tensorexpr/test_quantization.cpp deleted file mode 100644 index af6b539ff33e..000000000000 --- a/test/cpp/tensorexpr/test_quantization.cpp +++ /dev/null @@ -1,452 +0,0 @@ -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "torch/csrc/jit/tensorexpr/eval.h" -#include "torch/csrc/jit/tensorexpr/ir.h" - -namespace torch { -namespace jit { - -using namespace torch::jit::tensorexpr; -using SimpleIRExprEval = ExprEval; -using namespace torch::indexing; -using namespace torch::jit::tensorexpr; - -class Quantization : public ::testing::Test { - public: - void SetUp() override { - getTEMustUseLLVMOnCPU() = false; - } -}; - -TEST_F(Quantization, QuantDequantInt8) { - const auto graph_string = R"IR( - graph(%x.1 : Float(2, 2, strides=[2, 1], device=cpu)): - %2 : int = prim::Constant[value=12]() - %3 : int = prim::Constant[value=13]() - %4 : float = prim::Constant[value=0.1]() - %q.1 : QInt8(2, 2) = aten::quantize_per_tensor(%x.1, %4, %3, %2) - %6 : Float(2, 2) = aten::dequantize(%q.1) - return (%6))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto x = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - auto q = at::quantize_per_tensor(x, 0.1f, 13, at::kQInt8); - auto y_expected = at::dequantize(q); - TensorExprKernel k(graph); - std::vector inputs = {x}; - StmtPtr s = k.getCodeGenStmt(); - - std::vector stack = fmap(inputs); - k.run(stack); - auto y = stack[0].toTensor(); - bool check = at::allclose(y_expected, y); - if (!check) { - std::cout << "y_expected:\n" << y_expected << std::endl; - std::cout << "y:\n" << y << std::endl; - } - TORCH_CHECK_EQ(check, 1); -} - -TEST_F(Quantization, QuantDequantUInt8) { - const auto graph_string = R"IR( - graph(%x.1 : Float(2, 2, strides=[2, 1], device=cpu)): - %2 : int = prim::Constant[value=13]() - %3 : int = prim::Constant[value=122]() - %4 : float = prim::Constant[value=0.1]() - %q.1 : QUInt8(2, 2) = aten::quantize_per_tensor(%x.1, %4, %3, %2) - %6 : Float(2, 2) = aten::dequantize(%q.1) - return (%6))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto x = 2 * at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - auto q = at::quantize_per_tensor(x, 0.1f, 122, at::kQUInt8); - auto y_expected = at::dequantize(q); - TensorExprKernel k(graph); - std::vector inputs = {x}; - StmtPtr s = k.getCodeGenStmt(); - - std::vector stack = fmap(inputs); - k.run(stack); - auto y = stack[0].toTensor(); - bool check = at::allclose(y_expected, y); - if (!check) { - std::cout << "y_expected:\n" << y_expected << std::endl; - std::cout << "y:\n" << y << std::endl; - } - TORCH_CHECK_EQ(check, 1); -} - -TEST_F(Quantization, QuantDequantUInt8_NLC) { - const auto graph_string = R"IR( - graph(%x.1 : Float(1, 2, 2, strides=[4, 1, 2], device=cpu)): - %2 : int = prim::Constant[value=13]() - %3 : int = prim::Constant[value=122]() - %4 : float = prim::Constant[value=0.1]() - %q.1 : QUInt8(1, 2, 2) = aten::quantize_per_tensor(%x.1, %4, %3, %2) - %6 : Float(1, 2, 2) = aten::dequantize(%q.1) - return (%6))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto x = 2 * at::rand({1, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - x.unsafeGetTensorImpl()->set_sizes_and_strides( - std::initializer_list{1, 2, 2}, {4, 1, 2}); - auto q = at::quantize_per_tensor(x, 0.1f, 122, at::kQUInt8); - auto y_expected = at::dequantize(q); - TensorExprKernel k(graph); - 
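// The eager-mode reference these tests compare against can be reproduced in
// isolation; a minimal sketch using the same ATen calls and the same constants
// (scale 0.1, zero point 13) as the tests. Affine quantization stores q with
// value ~= scale * (q - zero_point), so a quantize/dequantize round trip
// perturbs each in-range element by at most half a step, scale / 2.
#include <ATen/ATen.h>

void quantRoundTrip() {
  auto x = at::rand({2, 2}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
  auto q = at::quantize_per_tensor(x, /*scale=*/0.1, /*zero_point=*/13, at::kQInt8);
  auto y = at::dequantize(q);
  // x is in [0, 1), so every value is representable and the rounding error
  // is bounded by scale / 2 = 0.05.
  TORCH_CHECK(at::allclose(x, y, /*rtol=*/0.0, /*atol=*/0.05 + 1e-6));
}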
std::vector inputs = {x}; - StmtPtr s = k.getCodeGenStmt(); - - std::vector stack = fmap(inputs); - k.run(stack); - auto y = stack[0].toTensor(); - bool check = at::allclose(y_expected, y); - if (!check) { - std::cout << "x:\n" << x << std::endl; - std::cout << "y_expected:\n" << y_expected << std::endl; - std::cout << "y:\n" << y << std::endl; - } - TORCH_CHECK_EQ(check, 1); -} - -at::Tensor quantized_add( - at::Tensor x1, - at::Tensor x2, - double scale, - int64_t zero) { - const auto qadd_op = - c10::Dispatcher::singleton() - .findSchemaOrThrow("quantized::add", "") - .typed(); - return qadd_op.call(x1, x2, scale, zero); -} - -TEST_F(Quantization, QuantAddDequantInt8) { - const auto graph_string = R"IR( - graph(%x1 : Float(2, 2, strides=[2, 1], device=cpu), %x2 : Float(2, 2, strides=[2, 1], device=cpu)): - %2 : int = prim::Constant[value=12]() - %qz1 : int = prim::Constant[value=13]() - %qs1 : float = prim::Constant[value=0.1]() - %qz2 : int = prim::Constant[value=13]() - %qs2 : float = prim::Constant[value=0.1]() - %qza : int = prim::Constant[value=13]() - %qsa : float = prim::Constant[value=0.1]() - %q1 : QInt8(2, 2) = aten::quantize_per_tensor(%x1, %qs1, %qz1, %2) - %q2 : QInt8(2, 2) = aten::quantize_per_tensor(%x2, %qs2, %qz2, %2) - %qa : QInt8(2, 2) = quantized::add(%q1, %q2, %qsa, %qza) - %6 : Float(2, 2) = aten::dequantize(%qa) - return (%6))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto x1 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - auto x2 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - auto q1 = at::quantize_per_tensor(x1, 0.1f, 13, at::kQInt8); - auto q2 = at::quantize_per_tensor(x2, 0.1f, 13, at::kQInt8); - auto qa = quantized_add(q1, q2, 0.1f, 13); - auto y_expected = at::dequantize(qa); - TensorExprKernel k(graph); - std::vector inputs = {x1, x2}; - StmtPtr s = k.getCodeGenStmt(); - - std::vector stack = fmap(inputs); - k.run(stack); - auto y = stack[0].toTensor(); - bool check = at::allclose(y_expected, y); - if (!check) { - std::cout << "x1:\n" << x1 << std::endl; - std::cout << "q1:\n" << q1 << std::endl; - std::cout << "x2:\n" << x2 << std::endl; - std::cout << "q2:\n" << q2 << std::endl; - std::cout << "y_expected:\n" << y_expected << std::endl; - std::cout << "y:\n" << y << std::endl; - } - TORCH_CHECK_EQ(check, 1); -} - -TEST_F(Quantization, QuantAddDequantUInt8) { - const auto graph_string = R"IR( - graph(%x1 : Float(2, 2, strides=[2, 1], device=cpu), %x2 : Float(2, 2, strides=[2, 1], device=cpu)): - %2 : int = prim::Constant[value=13]() - %qz1 : int = prim::Constant[value=13]() - %qs1 : float = prim::Constant[value=0.1]() - %qz2 : int = prim::Constant[value=13]() - %qs2 : float = prim::Constant[value=0.1]() - %qza : int = prim::Constant[value=13]() - %qsa : float = prim::Constant[value=0.1]() - %q1 : QUInt8(2, 2) = aten::quantize_per_tensor(%x1, %qs1, %qz1, %2) - %q2 : QUInt8(2, 2) = aten::quantize_per_tensor(%x2, %qs2, %qz2, %2) - %qa : QUInt8(2, 2) = quantized::add(%q1, %q2, %qsa, %qza) - %6 : Float(2, 2) = aten::dequantize(%qa) - return (%6))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto x1 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - auto x2 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - auto q1 = at::quantize_per_tensor(x1, 0.1f, 13, at::kQUInt8); - auto q2 = at::quantize_per_tensor(x2, 0.1f, 13, at::kQUInt8); - auto qa = quantized_add(q1, q2, 0.1f, 13); - auto y_expected = at::dequantize(qa); - - TensorExprKernel k(graph); - 
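// The quantized_add helper above reaches its kernel through the c10 dispatcher
// rather than a direct ATen function. The template argument of .typed<> was
// lost in extraction; given the helper's C++ signature and its call site, it
// is presumably the function type spelled out in this sketch.
#include <ATen/ATen.h>
#include <ATen/core/dispatch/Dispatcher.h>

at::Tensor callQuantizedAdd(
    const at::Tensor& qx, const at::Tensor& qy, double scale, int64_t zero_point) {
  static const auto op =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow("quantized::add", "")
          .typed<at::Tensor(at::Tensor, at::Tensor, double, int64_t)>();
  // Both inputs must already be quantized; scale/zero_point describe the output.
  return op.call(qx, qy, scale, zero_point);
}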
std::vector inputs = {x1, x2}; - StmtPtr s = k.getCodeGenStmt(); - - std::vector stack = fmap(inputs); - k.run(stack); - auto y = stack[0].toTensor(); - bool check = at::allclose(y_expected, y); - if (!check) { - std::cout << "x1:\n" << x1 << std::endl; - std::cout << "q1:\n" << q1 << std::endl; - std::cout << "x2:\n" << x2 << std::endl; - std::cout << "q2:\n" << q2 << std::endl; - std::cout << "y_expected:\n" << y_expected << std::endl; - std::cout << "y:\n" << y << std::endl; - } - TORCH_CHECK_EQ(check, 1); -} - -TEST_F(Quantization, QuantSigmoidDequantUInt8) { - const auto graph_string = R"IR( - graph(%x1 : Float(2, 2, strides=[2, 1], device=cpu)): - %2 : int = prim::Constant[value=13]() - %qz1 : int = prim::Constant[value=13]() - %qs1 : float = prim::Constant[value=0.1]() - %q1 : QUInt8(2, 2) = aten::quantize_per_tensor(%x1, %qs1, %qz1, %2) - %qa : QUInt8(2, 2) = aten::sigmoid(%q1) - %6 : Float(2, 2) = aten::dequantize(%qa) - return (%6))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto x1 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - auto q1 = at::quantize_per_tensor(x1, 0.1f, 13, at::kQUInt8); - auto qs = at::sigmoid(q1); - auto y_expected = at::dequantize(qs); - - TensorExprKernel k(graph); - std::vector inputs = {x1}; - StmtPtr s = k.getCodeGenStmt(); - - std::vector stack = fmap(inputs); - k.run(stack); - auto y = stack[0].toTensor(); - bool check = at::allclose(y_expected, y); - if (!check) { - std::cout << "x1:\n" << x1 << std::endl; - std::cout << "q1:\n" << q1 << std::endl; - std::cout << "qs:\n" << qs << std::endl; - std::cout << "y_expected:\n" << y_expected << std::endl; - std::cout << "y:\n" << y << std::endl; - } - TORCH_CHECK_EQ(check, 1); -} - -at::Tensor quantized_mul( - at::Tensor x1, - at::Tensor x2, - double scale, - int64_t zero) { - const auto op = - c10::Dispatcher::singleton() - .findSchemaOrThrow("quantized::mul", "") - .typed(); - return op.call(x1, x2, scale, zero); -} - -TEST_F(Quantization, QuantMulDequantUInt8) { - const auto graph_string = R"IR( - graph(%x1 : Float(2, 2, strides=[2, 1], device=cpu), %x2 : Float(2, 2, strides=[2, 1], device=cpu)): - %2 : int = prim::Constant[value=13]() - %qz1 : int = prim::Constant[value=13]() - %qs1 : float = prim::Constant[value=0.1]() - %qz2 : int = prim::Constant[value=13]() - %qs2 : float = prim::Constant[value=0.1]() - %qza : int = prim::Constant[value=13]() - %qsa : float = prim::Constant[value=0.1]() - %q1 : QUInt8(2, 2) = aten::quantize_per_tensor(%x1, %qs1, %qz1, %2) - %q2 : QUInt8(2, 2) = aten::quantize_per_tensor(%x2, %qs2, %qz2, %2) - %qa : QUInt8(2, 2) = quantized::mul(%q1, %q2, %qsa, %qza) - %6 : Float(2, 2) = aten::dequantize(%qa) - return (%6))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto x1 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - auto x2 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - auto q1 = at::quantize_per_tensor(x1, 0.1f, 13, at::kQUInt8); - auto q2 = at::quantize_per_tensor(x2, 0.1f, 13, at::kQUInt8); - auto qa = quantized_mul(q1, q2, 0.1f, 13); - auto y_expected = at::dequantize(qa); - - TensorExprKernel k(graph); - std::vector inputs = {x1, x2}; - StmtPtr s = k.getCodeGenStmt(); - - std::vector stack = fmap(inputs); - k.run(stack); - auto y = stack[0].toTensor(); - bool check = at::allclose(y_expected, y); - if (!check) { - std::cout << "x1:\n" << x1 << std::endl; - std::cout << "q1:\n" << q1 << std::endl; - std::cout << "x2:\n" << x2 << std::endl; - std::cout << "q2:\n" << q2 
<< std::endl; - std::cout << "y_expected:\n" << y_expected << std::endl; - std::cout << "y:\n" << y << std::endl; - } - TORCH_CHECK_EQ(check, 1); -} - -TEST_F(Quantization, QuantUpsampleNearst2dDequantUInt8) { - const auto graph_string = R"IR( - graph(%x : Float(1, 1, 4, 4, strides=[16, 16, 4, 1], device=cpu)): - %2 : int = prim::Constant[value=13]() - %4 : NoneType = prim::Constant() - %3 : int[] = prim::Constant[value=[6, 6]]() - %qz : int = prim::Constant[value=13]() - %qs : float = prim::Constant[value=0.1]() - %q : QUInt8(1, 1, 4, 4) = aten::quantize_per_tensor(%x, %qs, %qz, %2) - %qu : QUInt8(1, 1, 6, 6) = aten::upsample_nearest2d(%q, %3, %4) - %6 : Float(1, 1, 6, 6) = aten::dequantize(%qu) - return (%6))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto x = at::rand({1, 1, 4, 4}, TensorOptions(kCPU).dtype(at::kFloat)); - auto q = at::quantize_per_tensor(x, 0.1f, 13, at::kQUInt8); - auto qu = at::upsample_nearest2d(q, {6, 6}); - auto y_expected = at::dequantize(qu); - - TensorExprKernel k(graph); - std::vector inputs = {x}; - StmtPtr s = k.getCodeGenStmt(); - - std::vector stack = fmap(inputs); - k.run(stack); - auto y = stack[0].toTensor(); - bool check = at::allclose(y_expected, y); - if (!check) { - std::cout << "x:\n" << x << std::endl; - std::cout << "q:\n" << q << std::endl; - std::cout << "qu:\n" << qu << std::endl; - std::cout << "y_expected:\n" << y_expected << std::endl; - std::cout << "y:\n" << y << std::endl; - } - TORCH_CHECK_EQ(check, 1); -} - -TEST_F(Quantization, UpsampleNearst2d) { - const auto graph_string = R"IR( - graph(%x : Float(1, 1, 2, 2, strides=[2, 2, 2, 1], device=cpu)): - %4 : NoneType = prim::Constant() - %3 : int[] = prim::Constant[value=[4, 4]]() - %u : Float(1, 1, 4, 4) = aten::upsample_nearest2d(%x, %3, %4) - return (%u))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto x = at::rand({1, 1, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - auto y_expected = at::upsample_nearest2d(x, {4, 4}); - - TensorExprKernel k(graph); - std::vector inputs = {x}; - StmtPtr s = k.getCodeGenStmt(); - - std::vector stack = fmap(inputs); - k.run(stack); - auto y = stack[0].toTensor(); - bool check = at::allclose(y_expected, y); - if (!check) { - std::cout << "x:\n" << x << std::endl; - std::cout << "y_expected:\n" << y_expected << std::endl; - std::cout << "y:\n" << y << std::endl; - } - TORCH_CHECK_EQ(check, 1); -} - -at::Tensor quantized_cat( - c10::List const& xs, - int64_t dim, - double scale, - int64_t zero) { - const auto op = c10::Dispatcher::singleton() - .findSchemaOrThrow("quantized::cat", "") - .typed const&, - int64_t, - std::optional, - std::optional)>(); - return op.redispatch( - DispatchKeySet({DispatchKey::QuantizedCPU}), xs, dim, scale, zero); -} - -TEST_F(Quantization, QuantCatDequantUInt8) { - const auto graph_string = R"IR( - graph(%x : Float(1, 1, 2, 2, strides=[2, 2, 2, 1], device=cpu), %y : Float(1, 1, 2, 2, strides=[2, 2, 2, 1], device=cpu), %z : Float(1, 1, 2, 2, strides=[2, 2, 2, 1], device=cpu)): - %qdt : int = prim::Constant[value=13]() - %qxz : int = prim::Constant[value=13]() - %qxs : float = prim::Constant[value=0.1]() - %qyz : int = prim::Constant[value=16]() - %qys : float = prim::Constant[value=0.15]() - %qzz : int = prim::Constant[value=19]() - %qzs : float = prim::Constant[value=0.2]() - %qx : QUInt8(1, 1, 2, 2) = aten::quantize_per_tensor(%x, %qxs, %qxz, %qdt) - %qy : QUInt8(1, 1, 2, 2) = aten::quantize_per_tensor(%y, %qys, %qyz, %qdt) - %qz : QUInt8(1, 1, 2, 2) = 
aten::quantize_per_tensor(%z, %qzs, %qzz, %qdt) - %catx : Tensor[] = prim::ListConstruct(%qx, %qy, %qz) - %catd : int = prim::Constant[value=0]() - %qcat : QUInt8(3, 1, 2, 2) = quantized::cat(%catx, %catd, %qxs, %qxz) - %cat : Float(3, 1, 2, 2) = aten::dequantize(%qcat) - return (%cat))IR"; - auto graph = std::make_shared(); - parseIR(graph_string, &*graph); - - auto x = at::rand({1, 1, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - auto y = at::rand({1, 1, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - auto z = at::rand({1, 1, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); - auto qx = at::quantize_per_tensor(x, 0.1f, 13, at::kQUInt8); - auto qy = at::quantize_per_tensor(y, 0.15f, 16, at::kQUInt8); - auto qz = at::quantize_per_tensor(z, 0.2f, 19, at::kQUInt8); - auto qcat = quantized_cat({qx, qy, qz}, 0, 0.1f, 13); - auto expected = at::dequantize(qcat); - - TensorExprKernel k(graph); - std::vector inputs = {x, y, z}; - StmtPtr s = k.getCodeGenStmt(); - - std::vector stack = fmap(inputs); - k.run(stack); - auto result = stack[0].toTensor(); - bool check = at::allclose(expected, result); - if (!check) { - std::cout << "x:\n" << x << std::endl; - std::cout << "y:\n" << y << std::endl; - std::cout << "z:\n" << z << std::endl; - std::cout << "qx:\n" << qx << std::endl; - std::cout << "qy:\n" << qy << std::endl; - std::cout << "qz:\n" << qz << std::endl; - std::cout << "qcat:\n" << qcat << std::endl; - std::cout << "expected:\n" << expected << std::endl; - std::cout << "result:\n" << result << std::endl; - } - TORCH_CHECK_EQ(check, 1); -} - -} // namespace jit -} // namespace torch diff --git a/test/cpp/tensorexpr/test_reductions.cpp b/test/cpp/tensorexpr/test_reductions.cpp deleted file mode 100644 index fb83ab85b71e..000000000000 --- a/test/cpp/tensorexpr/test_reductions.cpp +++ /dev/null @@ -1,1928 +0,0 @@ -#include - -#include -#include -#include -#include -#include - -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace torch { -namespace jit { - -using namespace torch::jit::tensorexpr; - -TEST(Reductions, ReduceSum0D_1) { - const int M = 10; - - BufHandle b("b", {M}, kFloat); - std::vector in(M); - for (const auto j : c10::irange(M)) { - in[j] = j; - } - - std::vector out(M, -1.f); - - Tensor c = Reduce("sum", {M}, Sum(), b, {}); - LoopNest loop({c}); - loop.prepareForCodegen(); - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg(s, {b, c}); - - cg.call({in, out}); - for (const auto i : c10::irange(M)) { - ASSERT_EQ(out[i], in[i]); - } -} - -TEST(Reductions, ReduceSum0D_2) { - BufHandle b("b", {}, kFloat); - std::vector in(1); - in[0] = 77.7; - - std::vector out(1, -1.f); - - Tensor c = Reduce("sum", {}, Sum(), b, {}); - LoopNest loop({c}); - loop.prepareForCodegen(); - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg(s, {b, c}); - - cg.call({in, out}); - ASSERT_EQ(out[0], in[0]); -} - -// Sum an array to a single value. -TEST(Reductions, ReduceSum1D) { - BufHandle b("b", {10}, kFloat); - std::vector in(10); - for (const auto j : c10::irange(10)) { - in[j] = j; - } - - std::vector out(1, -1.f); - - Tensor c = Reduce("sum", {}, Sum(), b, {10}); - LoopNest loop({c}); - loop.prepareForCodegen(); - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg(s, {b, c}); - - cg.call({in, out}); - ASSERT_EQ(out[0], 45); -} -// Sum a 2D tensor to a 1D tensor with dynamic shapes. 
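// The fixed pipeline shared by essentially every test below, distilled into
// one sketch: build a reduction Tensor, lower it with LoopNest, simplify, and
// interpret with SimpleIREvaluator. Header paths follow the usual tensorexpr
// layout; the function name is illustrative.
#include <torch/csrc/jit/tensorexpr/eval.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/loopnest.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>
#include <vector>

void sumPipeline() {
  using namespace torch::jit::tensorexpr;
  BufHandle b("b", {10}, kFloat);
  // Reduce(name, output dims, reducer, source buffer, reduction dims); an
  // empty output-dim list produces a single scalar result.
  Tensor c = Reduce("sum", {}, Sum(), b, {10});
  LoopNest loop({c});
  loop.prepareForCodegen();
  StmtPtr s = IRSimplifier::simplify(loop.root_stmt());
  SimpleIREvaluator cg(s, {b, c}); // binding order matches the call below
  std::vector<float> in(10), out(1, -1.f);
  for (int j = 0; j < 10; ++j) in[j] = static_cast<float>(j);
  cg.call({in, out}); // out[0] == 45
}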
-TEST(Reductions, ReduceSum2D) { - const int M = 3; - const int N = 7; - - VarHandle m("m", kInt); - VarHandle n("n", kInt); - - BufHandle b("b", {m, n}, kFloat); - std::vector in(M * N); - for (const auto i : c10::irange(M)) { - for (const auto j : c10::irange(N)) { - in[i * N + j] = j; - } - } - - std::vector out(M, -1.f); - - Tensor c = Reduce("sum", {M}, Sum(), b, {N}); - LoopNest loop({c}); - loop.prepareForCodegen(); - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg(s, {b, c, n, m}); - - cg.call({in, out, 5, 7}); - - float expected = 0; - for (const auto i : c10::irange(N)) { - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) - expected += i; - } - - for (const auto i : c10::irange(M)) { - ASSERT_EQ(out[i], expected); - } -} - -// Sum a 3D tensor to both a 2D and 1D tensor, then reduce the 2D tensor flat to -// check our work. -TEST(Reductions, ReduceSum3D) { - const int M = 10; - VarHandle m("m", kInt); - - BufHandle b("b", {2, 3, m}, kFloat); - - Tensor c = Reduce("sum", {2, 3}, Sum(), b, {m}); - LoopNest loop({c}); - loop.prepareForCodegen(); - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg(s, {b, c, m}); - - std::vector bData(2 * 3 * M, 0); - std::vector cData(2 * 3, 6.0f); - std::vector dData(2, 1.0f); - std::vector eData(2, 1.0f); - - for (int i = 0; i < 2 * 3; ++i) { - for (const auto j : c10::irange(M)) { - bData[i * M + j] = j; - } - } - - cg.call({bData, cData, M}); - float expected = 0; - for (const auto i : c10::irange(M)) { - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) - expected += i; - } - - for (int i = 0; i < 2 * 3; ++i) { - ASSERT_EQ(cData[i], expected); - } - - Tensor d = Reduce("sum2", {2}, Sum(), b, {3, m}); - LoopNest loop2({d}); - loop2.prepareForCodegen(); - StmtPtr s2 = loop2.root_stmt(); - s2 = IRSimplifier::simplify(s2); - - SimpleIREvaluator cg2(s2, {b, d, m}); - cg2.call({bData, dData, M}); - - // We're combining an additional dimension of 3, so the sum is 3x. - expected = expected * 3; - - for (const auto i : c10::irange(2)) { - ASSERT_EQ(dData[i], expected); - } - - // This is the same as just reducing the original result across that axis. - BufHandle c_buf(c.buf()); - Tensor e = Reduce("sum3", {2}, Sum(), c_buf, {3}); - LoopNest loop3({e}); - loop3.prepareForCodegen(); - StmtPtr s3 = loop3.root_stmt(); - s3 = IRSimplifier::simplify(s3); - - SimpleIREvaluator cg3(s3, {c, e}); - cg3.call({cData, eData}); - - for (const auto i : c10::irange(2)) { - ASSERT_EQ(eData[i], expected); - } -} - -// Sum a large (10 D) Tensor 5 dimensions in. -TEST(Reductions, ReduceSum10D) { - BufHandle in_("in_", {2, 3, 2, 3, 2, 3, 2, 3, 2, 3}, kFloat); - const int InputSize = 2 * 3 * 2 * 3 * 2 * 3 * 2 * 3 * 2 * 3; - BufHandle out_("out_", {2, 3, 2, 3, 2}, kFloat); - const int OutputSize = 2 * 3 * 2 * 3 * 2; - - std::vector in(InputSize, 1.f); - std::vector out(OutputSize, -1.f); - - Tensor c = Reduce("sum", {2, 3, 2, 3, 2}, Sum(), in_, {3, 2, 3, 2, 3}); - LoopNest loop({c}); - loop.prepareForCodegen(); - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg(s, {in_, c}); - - cg.call({in, out}); - - // NOLINTNEXTLINE(bugprone-integer-division) - float expected = InputSize / OutputSize; - for (const auto i : c10::irange(OutputSize)) { - ASSERT_EQ(out[i], expected); - } -} - -// Reduce via Mul rather than Add using a custom Reducer. 
-TEST(Reductions, ReduceProduct) { - const int M = 4; - const int N = 4; - - BufHandle b("b", {M, N}, kFloat); - std::vector in(M * N); - for (const auto i : c10::irange(M)) { - for (const auto j : c10::irange(N)) { - in[i * N + j] = 2 + j; - } - } - - std::vector out(M, -1.f); - - Reducer product( - ExprHandle(1.f), [](ExprHandle a, ExprHandle b) { return a * b; }); - - Tensor c = Reduce("product", {M}, product, b, {N}); - LoopNest loop({c}); - loop.prepareForCodegen(); - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg(s, {b, c}); - - cg.call({in, out}); - - float expected = 1; - for (const auto i : c10::irange(N)) { - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) - expected *= 2 + i; - } - - for (const auto i : c10::irange(M)) { - ASSERT_EQ(out[i], expected); - } -} - -// Maximum reductions. -TEST(Reductions, ReduceMax) { - BufHandle in_("b", {10}, kFloat); - - std::vector in(10); - std::vector out(1, -1.f); - for (const auto j : c10::irange(10)) { - in[j] = j; - } - - Tensor dm1 = Reduce("max", {}, Maximum(kFloat), in_, {10}); - - LoopNest loop({dm1}); - loop.prepareForCodegen(); - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - SimpleIREvaluator cg(s, {in_, dm1}); - - cg.call({in, out}); - - ASSERT_EQ(out[0], 9); - - BufHandle in2_("b", {2, 5}, kFloat); - std::vector out2(2, -1.f); - - Tensor m2d = Reduce("max", {2}, Maximum(kFloat), in2_, {5}); - - LoopNest loop2({m2d}); - loop2.prepareForCodegen(); - s = loop2.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg2(s, {in2_, m2d}); - cg2.call({in, out2}); - - ASSERT_EQ(out2[0], 4); - ASSERT_EQ(out2[1], 9); -} - -// Minimum reduction, with custom initialization. -TEST(Reductions, ReduceMinCustomInitializer) { - VarHandle minInit("minInit", kFloat); - BufHandle in_("b", {10}, kFloat); - - std::vector in(10); - std::vector out(1, -1.f); - for (const auto j : c10::irange(10)) { - in[j] = 10 + j; - } - - Tensor min = Reduce( - "min", - {}, - Minimum(ExprHandle(minInit)), - [&](ParameterList& v) { return in_.load(v); }, - {10}); - - LoopNest loop({min}); - loop.prepareForCodegen(); - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg(s, {in_, min, minInit}); - - // Works normally (note that out data starts lower than the correct - // minimum). - cg.call({in, out, std::numeric_limits::max()}); - ASSERT_EQ(out[0], 10); - - // With an initializer lower than the min, that's the min. - cg.call({in, out, 5.f}); - ASSERT_EQ(out[0], 5); -} - -// Example implementation of Any/All. -// TODO: this is very awkward without logical And/Or operators. -TEST(Reductions, ReduceAnyAll) { - VarHandle searchValue("searchValue", kInt); - BufHandle b("b", {4, 10}, kInt); - - Reducer anyEqSV(ExprHandle(0), [](ExprHandle a, ExprHandle b) { - return CompareSelect::make(a, 1, 1, b, kEQ); - }); - - Tensor any = Reduce( - "anyEqual", - {4}, - anyEqSV, - [&](const auto& i, const auto& j) { - return CompareSelect::make(b.load(i, j), searchValue, kEQ); - }, - {10}); - - LoopNest loop({any}); - loop.prepareForCodegen(); - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg(s, {b, any, searchValue}); - - std::vector in(40, 0); - std::vector out(4, 0); - - // input has 0-39 in 4 rows. 
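// Why the reducers above work, in a distilled sketch: CompareSelect takes
// (lhs, rhs, true_value, false_value, op), and its two-argument form yields
// 1 or 0. Since the IR has no logical And/Or, "any" is a sticky OR and "all"
// is the dual sticky AND; initializers 0 and 1 are the respective identities.
Reducer anyOf(ExprHandle(0), [](ExprHandle acc, ExprHandle elem) {
  // once the accumulator reaches 1 it stays 1; otherwise take the next element
  return CompareSelect::make(acc, 1, 1, elem, kEQ);
});
Reducer allOf(ExprHandle(1), [](ExprHandle acc, ExprHandle elem) {
  // once the accumulator reaches 0 it stays 0
  return CompareSelect::make(acc, 0, 0, elem, kEQ);
});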
- for (const auto i : c10::irange(40)) { - in[i] = i; - } - cg.call({in, out, 1}); - - // only the first row has 1 - ASSERT_EQ(out[0], 1); - ASSERT_EQ(out[1], 0); - ASSERT_EQ(out[2], 0); - ASSERT_EQ(out[3], 0); - - cg.call({in, out, 15}); - - // 15 in the 3rd row - ASSERT_EQ(out[0], 0); - ASSERT_EQ(out[1], 1); - ASSERT_EQ(out[2], 0); - ASSERT_EQ(out[3], 0); - - Reducer allGTSV(ExprHandle(1), [](ExprHandle a, ExprHandle b) { - return CompareSelect::make(a, 0, 0, b, kEQ); - }); - - Tensor allGreaterThan = Reduce( - "allGreaterThan", - {4}, - allGTSV, - [&](const auto& i, const auto& j) { - return CompareSelect::make(b.load(i, j), searchValue, kGT); - }, - {10}); - - LoopNest loop2({allGreaterThan}); - loop2.prepareForCodegen(); - s = loop2.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg2(s, {b, allGreaterThan, searchValue}); - - cg2.call({in, out, 11}); - - // 11 is in row 2. - ASSERT_EQ(out[0], 0); - ASSERT_EQ(out[1], 0); - ASSERT_EQ(out[2], 1); - ASSERT_EQ(out[3], 1); - - cg2.call({in, out, -3}); - - // All are positive. - ASSERT_EQ(out[0], 1); - ASSERT_EQ(out[1], 1); - ASSERT_EQ(out[2], 1); - ASSERT_EQ(out[3], 1); -} - -TEST(Reductions, ReduceMatmul2D) { - BufHandle tA("tA", {3, 2}, kFloat); - BufHandle tB("tB", {2, 3}, kFloat); - - std::vector tA_(6); - std::vector tB_(6); - - std::vector out(9, -1.f); - for (const auto i : c10::irange(3)) { - for (const auto j : c10::irange(2)) { - tA_[i * 2 + j] = i * 2 + j; - tB_[j * 3 + i] = i * 2 + j; - } - } - - Tensor mm = Reduce( - "mm", - {3, 3}, - Sum(), - [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) { - return tA.load(m, k) * tB.load(k, n); - }, - {2}); - - LoopNest loop({mm}); - loop.prepareForCodegen(); - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg(s, {tA, tB, mm}); - cg.call({tA_, tB_, out}); - - std::vector expected( - {1.f, 3.f, 5.f, 3.f, 13.f, 23.f, 5.f, 23.f, 41.f}); - - for (const auto i : c10::irange(9)) { - ASSERT_EQ(out[i], expected[i]); - } -} - -TEST(Reductions, ReduceRfactorLike) { - BufHandle in("in", {10, 10}, kFloat); - std::vector in_(100); - for (const auto i : c10::irange(100)) { - in_[i] = i; - } - std::vector in_rf_(10, -2.f); - std::vector out(1, -1.f); - - Tensor l1 = Reduce("l1", {10}, Sum(), in, {10}); - BufHandle in_rf(l1.buf()); - - Tensor l2 = Reduce("l2", {}, Sum(), in_rf, {10}); - - LoopNest loop({l1, l2}); - loop.prepareForCodegen(); - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg(s, {in, l1, l2}); - cg.call({in_, in_rf_, out}); - - ASSERT_EQ(out[0], 99 * 50); -} - -TEST(Reductions, ReduceAsProducer) { - const int M = 10; - VarHandle m("m", kInt); - - BufHandle a("a", {2, 3}, kFloat); - BufHandle b("b", {2, 3, m}, kFloat); - - Tensor c = Reduce("sum", {2, 3}, Sum(), b, {m}); - Tensor d = - Compute("scale", {2, 3}, [&](const VarHandle& l, const VarHandle& n) { - return c.load(l, n) * a.load(l, n); - }); - LoopNest loop({d}, {c, d}); - loop.prepareForCodegen(); - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg(s, {a, b, d, m}); - - std::vector aData(2 * 3, 0); - std::vector bData(2 * 3 * M, 0); - std::vector dData(2 * 3, 6.0f); - - for (int i = 0; i < 2 * 3; ++i) { - aData[i] = 6 - i; - for (const auto j : c10::irange(M)) { - bData[i * M + j] = j; - } - } - - cg.call({aData, bData, dData, M}); - float expected = 0; - for (const auto i : c10::irange(M)) { - // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) - 
expected += i; - } - for (int i = 0; i < 2 * 3; ++i) { - ASSERT_EQ(dData[i], expected * (6 - i)); - } -} - -TEST(Reductions, ReduceAsConsumer) { - const int M = 10; - VarHandle m("m", kInt); - - BufHandle a("a", {2, 3, m}, kFloat); - BufHandle b("b", {2, 3, m}, kFloat); - - Tensor c = Compute( - "scale", - {2, 3, m}, - [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) { - return b.load(l, n, m) * a.load(l, n, m); - }); - Tensor d = Reduce("sum", {2}, Sum(), c, {3, m}); - LoopNest loop({d}, {c, d}); - loop.prepareForCodegen(); - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg(s, {a, b, d, m}); - - std::vector aData(2 * 3 * M, 0); - std::vector bData(2 * 3 * M, 0); - std::vector dData(2, 6.0f); - - for (int i = 0; i < 2 * 3; ++i) { - for (const auto j : c10::irange(M)) { - bData[i * M + j] = j + 1; - aData[i * M + j] = 6 - i; - } - } - - cg.call({aData, bData, dData, M}); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) - float expected[2] = {0, 0}; - for (const auto i : c10::irange(2)) { - for (const auto j : c10::irange(3)) { - for (const auto k : c10::irange(M)) { - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) - expected[i] += (k + 1) * (6 - (i * 3 + j)); - } - } - } - - for (const auto i : c10::irange(2)) { - ASSERT_EQ(dData[i], expected[i]); - } -} - -TEST(Reductions, SplitReduceAxis) { - BufHandle in("in", {16, 8}, kFloat); - - std::vector in_(16 * 8); - for (const auto i : c10::irange(16)) { - for (const auto j : c10::irange(8)) { - in_[i * 8 + j] = i; - } - } - std::vector out(16, -1.f); - - Tensor tensor = Reduce("sum", {16}, Sum(), in, {8}); - LoopNest l({tensor}); - std::vector loops = l.getLoopStmtsFor(tensor); - LoopNest::splitWithTail(loops[1], 2); - - l.prepareForCodegen(); - - StmtPtr s = l.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg(s, {in, tensor}); - cg.call({in_, out}); - - for (const auto i : c10::irange(16)) { - ASSERT_EQ(out[i], i * 8); - } -} - -TEST(Reductions, SplitNonReduceAxis) { - BufHandle in("in", {16, 8}, kFloat); - - std::vector in_(16 * 8); - for (const auto i : c10::irange(16)) { - for (const auto j : c10::irange(8)) { - in_[i * 8 + j] = i; - } - } - std::vector out(16, -1.f); - Tensor tensor = Reduce("sum", {16}, Sum(), in, {8}); - LoopNest l({tensor}); - std::vector loops = l.getLoopStmtsFor(tensor); - LoopNest::splitWithTail(loops[0], 2); - LoopNest::splitWithTail(loops[0], 2); - - l.prepareForCodegen(); - - StmtPtr s = l.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg(s, {in, tensor}); - cg.call({in_, out}); - - for (const auto i : c10::irange(16)) { - ASSERT_EQ(out[i], i * 8); - } -} - -TEST(Reductions, ReorderedReductionInitializer) { - /* From the quip: - for k in 0..1: // blockIdx - for m in 0..128: - for n in 0..64: // threadIdx - SumOp(c(k, n), 0, a(k, m, n), {m}) - */ - - BufHandle in("in", {1, 12, 6}, kFloat); - std::vector in_(12 * 6, 1.f); - - Tensor tensor_ = Reduce("sum", {1, 12}, Sum(), in, {6}); - LoopNest l_({tensor_}); - - l_.prepareForCodegen(); - StmtPtr s_ = Stmt::clone(l_.root_stmt()); - s_ = IRSimplifier::simplify(s_); - - Tensor tensor = Reduce("sum", {1, 12}, Sum(), in, {6}); - LoopNest l({tensor}); - - auto loops = l.getLoopStmtsFor(tensor); - loops[0]->set_gpu_block_index(0); - loops[1]->set_gpu_thread_index(0); - - LoopNest::reorderAxis(loops[1], loops[2]); - - StmtPtr s = l.root_stmt(); - // 
NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) - s = IRSimplifier::simplify(s); - - l.prepareForCodegen(); - - s = l.root_stmt(); - s = IRSimplifier::simplify(s); - - std::vector out1(16, -1.f); - SimpleIREvaluator cg(s_, {in, tensor_}); - cg.call({in_, out1}); - - std::vector out2(16, -1.f); - SimpleIREvaluator cg2(s, {in, tensor}); - cg2.call({in_, out2}); - - for (const auto i : c10::irange(16)) { - ASSERT_EQ(out1[i], out2[i]); - } -} - -TEST(Reductions, ReduceRfactor) { - const int M = 10; - const int N = 10; - VarHandle m("m", kInt); - VarHandle n("n", kInt); - - BufHandle b("b", {m, n}, kFloat); - std::vector in(M * N); - for (int j = 0; j < M * N; ++j) { - in[j] = j; - } - - std::vector out(1, -1.f); - - Tensor c = Reduce("sum", {}, Sum(), b, {m, n}); - LoopNest loop({c}); - std::vector loops = loop.getLoopStmtsFor(c); - auto c_body = loop.getAllWritesToBuf(c.buf())[1]; - ASSERT_TRUE(loop.rfactor(c_body, loops.at(0))); - auto rc = NodeFinder::find(loop.root_stmt()); - ASSERT_EQ(rc.size(), 2); - loop.prepareForCodegen(); - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg(s, {b, c, m, n}); - - cg.call({in, out, M, N}); - ASSERT_EQ(out[0], 4950); -} - -TEST(Reductions, Reduce3DRfactorInner) { - const int M = 10; - const int N = 10; - const int K = 10; - VarHandle m("m", kInt); - VarHandle n("n", kInt); - VarHandle k("k", kInt); - - BufHandle b("b", {m, n, k}, kFloat); - std::vector in(M * N * K); - for (int j = 0; j < M * N * K; ++j) { - in[j] = j; - } - - std::vector out(1, -1.f); - - Tensor c = Reduce("sum", {}, Sum(), b, {m, n, k}); - LoopNest loop({c}); - std::vector loops = loop.getLoopStmtsFor(c); - auto c_body = loop.getAllWritesToBuf(c.buf())[1]; - ASSERT_FALSE(loop.rfactor(c_body, loops.at(2))); - auto rc = NodeFinder::find(loop.root_stmt()); - ASSERT_EQ(rc.size(), 1); - loop.prepareForCodegen(); - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg(s, {b, c, m, n, k}); - - cg.call({in, out, M, N, K}); - ASSERT_EQ(out[0], 499500); -} - -TEST(Reductions, Reduce3DRfactorOuter) { - const int M = 10; - const int N = 10; - const int K = 10; - VarHandle m("m", kInt); - VarHandle n("n", kInt); - VarHandle k("k", kInt); - - BufHandle b("b", {m, n, k}, kFloat); - std::vector in(M * N * K); - for (int j = 0; j < M * N * K; ++j) { - in[j] = j; - } - - std::vector out(1, -1.f); - - Tensor c = Reduce("sum", {}, Sum(), b, {m, n, k}); - LoopNest loop({c}); - std::vector loops = loop.getLoopStmtsFor(c); - auto c_body = loop.getAllWritesToBuf(c.buf())[1]; - ASSERT_TRUE(loop.rfactor(c_body, loops.at(0))); - auto rc = NodeFinder::find(loop.root_stmt()); - ASSERT_EQ(rc.size(), 2); - loop.prepareForCodegen(); - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg(s, {b, c, m, n, k}); - cg.call({in, out, M, N, K}); - ASSERT_EQ(out[0], 499500); -} - -TEST(Reductions, ReduceRepeatedInternalRfactor) { - BufHandle in_("in_", {2, 3, 4, 5, 6}, kFloat); - const int InputSize = 2 * 3 * 4 * 5 * 6; - - std::vector in(InputSize, 1.f); - std::vector out(1, -1.f); - std::vector ref(1, -1.f); - - Tensor c = Reduce("sum", {}, Sum(), in_, {2, 3, 4, 5, 6}); - LoopNest orig_loop({c}); - - // Try rfactoring N outer loops - for (const auto rfac_number : c10::irange(1, 5)) { - LoopNest refloop(orig_loop); - LoopNest loop(orig_loop); - refloop.prepareForCodegen(); - SimpleIREvaluator ref_cg( - IRSimplifier::simplify(refloop.root_stmt()), {in_, c}); - ref_cg.call({in, ref}); - - BufPtr tmp_buf = c.buf(); - - 
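// What rfactor does in the tests above, sketched. Given a reduction over
// several axes, rfactor(body, axis) materializes a temporary buffer of partial
// results indexed by that axis, turning one ReduceOp into two; hence the
// NodeFinder counts of 2 after a successful rfactor (NodeFinder's template
// argument, presumably ReduceOp, was lost in extraction). For a sum over
// {m, n}:
//
//   for m: for n: sum += b[m, n]
//
// becomes, after rfactor on the m axis:
//
//   for m: for n: sum_rfac[m] += b[m, n]   // partial sums, independent in m
//   for m:        sum += sum_rfac[m]       // final combine
//
// The innermost reduction axis cannot be factored out this way, which is why
// Reduce3DRfactorInner asserts that the call returns false.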
for (const auto idx : c10::irange(rfac_number)) { - auto reduce = loop.getAllWritesToBuf(tmp_buf)[1]; - ASSERT_TRUE(loop.rfactor( - reduce, loop.getLoopStmtsFor(tmp_buf).at(idx), &tmp_buf)); - } - - loop.prepareForCodegen(); - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg(s, {in_, c}); - cg.call({in, out}); - - ASSERT_EQ(ref[0], out[0]); - } -} - -// Split a reduction axis with a tail loop. -TEST(Reductions, ReduceSplitTail) { - const int M = 10; - const int N = 10; - const int K = 10; - - BufHandle b("b", {M, N, K}, kFloat); - std::vector in(M * N * K); - for (int j = 0; j < M * N * K; ++j) { - in[j] = j; - } - - for (const auto i : c10::irange(3)) { - std::vector out(M, -1.f); - - Tensor c = Reduce("sum", {M}, Sum(), b, {N, K}); - LoopNest loop({c}); - std::vector loops = loop.getLoopStmtsFor(c); - LoopNest::splitWithTail(loops[i], 8); - - loop.prepareForCodegen(); - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg(s, {b, c}); - - cg.call({in, out}); - ASSERT_EQ(out[0], 4950); - } -} - -// Split a reduction axis cleanly so there is no tail loop. -TEST(Reductions, ReduceSplitNoTail) { - const int M = 10; - const int N = 10; - const int K = 10; - BufHandle b("b", {M, N, K}, kFloat); - std::vector in(M * N * K); - for (int j = 0; j < M * N * K; ++j) { - in[j] = j; - } - - for (const auto i : c10::irange(3)) { - std::vector out(M, -1.f); - - Tensor c = Reduce("sum", {M}, Sum(), b, {N, K}); - LoopNest loop({c}); - std::vector loops = loop.getLoopStmtsFor(c); - LoopNest::splitWithTail(loops[i], 5); - - loop.prepareForCodegen(); - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg(s, {b, c}); - - cg.call({in, out}); - ASSERT_EQ(out[0], 4950); - } -} - -// Split a reduction axis with only a tail loop (the split loop will be size 0 -// and eliminated out). -TEST(Reductions, ReduceOverSplitTail) { - const int M = 10; - const int N = 10; - const int K = 10; - - BufHandle b("b", {M, N, K}, kFloat); - std::vector in(M * N * K); - for (int j = 0; j < M * N * K; ++j) { - in[j] = j; - } - - for (const auto i : c10::irange(3)) { - std::vector out(M, -1.f); - - Tensor c = Reduce("sum", {M}, Sum(), b, {N, K}); - LoopNest loop({c}); - std::vector loops = loop.getLoopStmtsFor(c); - LoopNest::splitWithTail(loops[i], 16); - - loop.prepareForCodegen(); - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg(s, {b, c}); - - cg.call({in, out}); - ASSERT_EQ(out[0], 4950); - } -} - -// Split a reduction axis with a mask. -TEST(Reductions, ReduceSplitMask) { - const int M = 10; - const int N = 10; - const int K = 10; - - BufHandle b("b", {M, N, K}, kFloat); - std::vector in(M * N * K); - for (int j = 0; j < M * N * K; ++j) { - in[j] = j; - } - - for (const auto i : c10::irange(3)) { - std::vector out(M, -1.f); - - Tensor c = Reduce("sum", {M}, Sum(), b, {N, K}); - LoopNest loop({c}); - std::vector loops = loop.getLoopStmtsFor(c); - LoopNest::splitWithMask(loops[i], 8); - - loop.prepareForCodegen(); - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg(s, {b, c}); - - cg.call({in, out}); - ASSERT_EQ(out[0], 4950); - } -} - -// Split a reduction axis cleanly not requiring a mask. 
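// The distinction the two split families in the following tests exercise,
// sketched for a reduction axis of trip count 10:
//
//   LoopNest::splitWithTail(loop, 8);  // outer x 8 inner, plus an explicit
//                                      // tail loop for the 10 % 8 remainder
//   LoopNest::splitWithTail(loop, 5);  // 10 % 5 == 0, so no tail is emitted
//   LoopNest::splitWithMask(loop, 8);  // rounded-up loop whose body is guarded
//                                      // by a bounds check instead of a tail
//
// Over-splitting (factor 16 > 10) degenerates: with a tail, the main split
// loop is empty and only the tail survives; with a mask, every iteration runs
// under the guard. All variants must still produce the same sum, 4950 here.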
-TEST(Reductions, ReduceSplitNoMask) { - const int M = 10; - const int N = 10; - const int K = 10; - BufHandle b("b", {M, N, K}, kFloat); - std::vector in(M * N * K); - for (int j = 0; j < M * N * K; ++j) { - in[j] = j; - } - - for (const auto i : c10::irange(3)) { - std::vector out(M, -1.f); - - Tensor c = Reduce("sum", {M}, Sum(), b, {N, K}); - LoopNest loop({c}); - std::vector loops = loop.getLoopStmtsFor(c); - LoopNest::splitWithMask(loops[i], 5); - - loop.prepareForCodegen(); - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg(s, {b, c}); - - cg.call({in, out}); - ASSERT_EQ(out[0], 4950); - } -} - -// Split a reduction axis with all logic in the mask. -TEST(Reductions, ReduceOverSplitMask) { - const int M = 10; - const int N = 10; - const int K = 10; - - BufHandle b("b", {M, N, K}, kFloat); - std::vector in(M * N * K); - for (int j = 0; j < M * N * K; ++j) { - in[j] = j; - } - - for (const auto i : c10::irange(3)) { - std::vector out(M, -1.f); - - Tensor c = Reduce("sum", {M}, Sum(), b, {N, K}); - LoopNest loop({c}); - std::vector loops = loop.getLoopStmtsFor(c); - LoopNest::splitWithMask(loops[i], 16); - - loop.prepareForCodegen(); - StmtPtr s = loop.root_stmt(); - s = IRSimplifier::simplify(s); - - SimpleIREvaluator cg(s, {b, c}); - - cg.call({in, out}); - ASSERT_EQ(out[0], 4950); - } -} - -// Test an rfactor when there are two ReduceOps in the graph due to a -// splitWithTail. -TEST(Reductions, ReduceSplitRfactor) { - const int M = 2; - const int N = 10; - const int K = 10; - const int SPLIT_FACTOR = 4; - - BufHandle b("b", {M, N, K}, kFloat); - std::vector in(M * N * K); - for (const auto m : c10::irange(M)) { - for (int j = 0; j < N * K; ++j) { - in[m * N * K + j] = j; - } - } - - std::vector out(M, -1.f); - - Tensor c = Reduce("sum", {M}, Sum(), b, {N, K}); - LoopNest loop({c}); - std::vector loops = loop.getLoopStmtsFor(c); - LoopNest::splitWithTail(loops[2], SPLIT_FACTOR); - - auto c_body = loop.getAllWritesToBuf(c.buf())[2]; - auto all_loops = loop.getAllLoopNestsWritingToBuf(c.buf()); - ASSERT_TRUE(all_loops.size() == 3 && all_loops.at(2).size() == 3); - LoopNest::reorderAxis(all_loops[2][1], all_loops[2][2]); - all_loops = loop.getAllLoopNestsWritingToBuf(c.buf()); - ASSERT_TRUE(all_loops.size() == 3 && all_loops.at(2).size() == 3); - ASSERT_TRUE(loop.rfactor(c_body, all_loops[2][1])); - loop.prepareForCodegen(); - loop.simplify(); - StmtPtr s = loop.root_stmt(); - - SimpleIREvaluator cg(s, {b, c}); - - cg.call({in, out}); - for ([[maybe_unused]] const auto i : c10::irange(M)) { - ASSERT_EQ(out[0], 4950); - } -} - -// Test an rfactor which ends up being eliminated since the total loop size is -// smaller than the split factor. 
-TEST(Reductions, ReduceOverSplitRfactor) { - const int N = 10; - const int K = 10; - const int SPLIT_FACTOR = 16; - - BufHandle b("b", {N, K}, kFloat); - std::vector in(N * K); - for (int j = 0; j < N * K; ++j) { - in[j] = j; - } - - std::vector out(1, -1.f); - - Tensor c = Reduce("sum", {}, Sum(), b, {N, K}); - LoopNest loop({c}); - std::vector loops = loop.getLoopStmtsFor(c); - ForPtr i, t; - LoopNest::splitWithTail(loops[1], SPLIT_FACTOR, &i, &t); - LoopNest::reorderAxis(loops[0], i); - - auto all_loops = loop.getAllLoopNestsWritingToBuf(c.buf()); - ASSERT_TRUE(all_loops.size() == 3 && all_loops.at(1).size() == 3); - auto c_body = loop.getAllWritesToBuf(c.buf())[1]; - ASSERT_TRUE(loop.rfactor(c_body, all_loops[1][0])); - LoopNest::reorderAxis(all_loops[1][0], all_loops[1][2]); - - loop.prepareForCodegen(); - loop.simplify(); - StmtPtr s = loop.root_stmt(); - - SimpleIREvaluator cg(s, {b, c}); - - cg.call({in, out}); - ASSERT_EQ(out[0], 4950); - - std::ostringstream oss; - oss << *cg.stmt(); - - // Check the IR to verify the rfactored reduce is eliminated. - // TODO: The alloc free should be eliminated here since it is size 0. - /* - const std::string& verification_pattern = - R"IR( -# CHECK: Allocate(tmp_buf); // dtype=float, dims=[0] -# CHECK: sum[0] = 0.f; -# CHECK: for (int n = 0; n < 10; n++) { -# CHECK: for (int k_tail = 0; k_tail < 10; k_tail++) { -# CHECK: sum[0] = (sum[0]) + (b[k_tail + 10 * n]); -# CHECK: } -# CHECK: } -# CHECK: Free(tmp_buf);)IR"; - */ - // TODO: rfactor output is not consistent yet, will fix (@nickg). - // torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(Reductions, ReduceInlineReduction) { - const int M = 4; - const int N = 5; - const int K = 6; - - BufHandle a_buf("a", {M}, kFloat); - BufHandle b_buf("b", {M, N, K}, kFloat); - - Tensor x = Reduce("x", {M}, Sum(), b_buf, {N, K}); - Tensor y = Compute( - "y", {M}, [&](const VarHandle& m) { return a_buf.load(m) + x.load(m); }); - - PaddedBuffer a_v(M); - PaddedBuffer b_v(M, N, K); - - for (const auto i : c10::irange(M)) { - a_v(i) = i * i; - } - for (const auto i : c10::irange(M)) { - for (const auto j : c10::irange(N)) { - for (const auto k : c10::irange(K)) { - b_v(i, j, k) = j * j * k; - } - } - } - - LoopNest l1({y}, {x, y}); - // Cannot inline a reduction computation - ASSERT_FALSE(l1.computeInline(x.buf())); -} - -TEST(Reductions, ReduceInlineConsumer) { - const int M = 4; - const int N = 5; - const int K = 6; - - BufHandle a_buf("a", {M, N, K}, kFloat); - BufHandle b_buf("b", {M, N, K}, kFloat); - - Tensor x = Compute( - "x", - {M, N, K}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return a_buf.load(m, n, k) + b_buf.load(m, n, k); - }); - Tensor y = Reduce("y", {M}, Sum(), x, {N, K}); - - PaddedBuffer a_v(M, N, K); - PaddedBuffer b_v(M, N, K); - - for (const auto i : c10::irange(M)) { - for (const auto j : c10::irange(N)) { - for (const auto k : c10::irange(K)) { - a_v(i, j, k) = i * i + k; - b_v(i, j, k) = j * j + k; - } - } - } - - LoopNest l1({y}, {x, y}); - LoopNest l2(l1); - l2.computeInline(x.buf()); - - l1.prepareForCodegen(); - l2.prepareForCodegen(); - - StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt()); - StmtPtr stmt2 = IRSimplifier::simplify(l2.root_stmt()); - - SimpleIREvaluator eval1(stmt1, {a_buf, b_buf, y}); - SimpleIREvaluator eval2(stmt2, {a_buf, b_buf, y}); - - PaddedBuffer y_1(M); - PaddedBuffer y_2(M); - - eval1(a_v, b_v, y_1); - eval2(a_v, b_v, y_2); - ExpectAllNear(y_1, y_2, 1e-5); - std::ostringstream oss1, oss2; - 
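// The inlining rule the two tests above pin down, sketched: computeInline
// substitutes a tensor's defining expression into its uses, so it applies only
// to pure elementwise definitions. A reduction has no per-index defining
// expression (its value is an accumulation), so inlining it fails; inlining an
// elementwise producer into a reduction succeeds and eliminates the
// intermediate buffer, which is why the inlined IR prints shorter below:
//
//   Tensor x = Compute("x", {M, N, K}, ...);        // elementwise
//   Tensor y = Reduce("y", {M}, Sum(), x, {N, K});  // consumes x
//   LoopNest l({y}, {x, y});
//   l.computeInline(x.buf());  // true: loads of x become its expression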
oss1 << *stmt1; - oss2 << *stmt2; - ASSERT_GT(oss1.str().size(), oss2.str().size()); -} - -TEST(Reductions, ReduceInlineReducerInternal) { - const int M = 4; - const int N = 5; - const int K = 6; - - BufHandle a_buf("a", {M, N, K}, kFloat); - BufHandle b_buf("b", {M, N, K}, kFloat); - - Tensor x = Compute( - "x", - {M, N, K}, - [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return a_buf.load(m, n, k) + b_buf.load(m, n, k); - }); - - Reducer minimum(ExprHandle(0.f), [&](ExprHandle a, ExprHandle b) { - return Add::make(ExprHandle(1.f), Min::make(a, b, false)); - }); - Tensor y = Reduce("y", {M}, minimum, x, {N, K}); - - PaddedBuffer a_v(M, N, K); - PaddedBuffer b_v(M, N, K); - - for (const auto i : c10::irange(M)) { - for (const auto j : c10::irange(N)) { - for (const auto k : c10::irange(K)) { - a_v(i, j, k) = i * i + k; - b_v(i, j, k) = j * j + k; - } - } - } - - LoopNest l1({y}, {x, y}); - LoopNest l2(l1); - l2.computeInline(x.buf()); - - l1.prepareForCodegen(); - l2.prepareForCodegen(); - - StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt()); - StmtPtr stmt2 = IRSimplifier::simplify(l2.root_stmt()); - - SimpleIREvaluator eval1(stmt1, {a_buf, b_buf, y}); - SimpleIREvaluator eval2(stmt2, {a_buf, b_buf, y}); - - PaddedBuffer y_1(M); - PaddedBuffer y_2(M); - - eval1(a_v, b_v, y_1); - eval2(a_v, b_v, y_2); - ExpectAllNear(y_1, y_2, 1e-5); - std::ostringstream oss1, oss2; - oss1 << *stmt1; - oss2 << *stmt2; - ASSERT_GT(oss1.str().size(), oss2.str().size()); -} - -TEST(Reductions, ReductionCacheAccessesOperatorAxis) { - int L = 4; - int N = 3; - int M = 2; - - BufHandle a("a", {L, N, M}, kFloat); - BufHandle b("b", {L, N, M}, kFloat); - - Tensor c = Compute( - "scale", - {L, N, M}, - [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) { - return b.load(l, n, m) * a.load(l, n, m); - }); - Tensor d = Reduce("sum", {L}, Sum(), c, {N, M}); - - Tensor e = Compute("scale", {L}, [&](const VarHandle& l) { - return b.load(0, 0, l) * d.load(l); - }); - - LoopNest l({e}, {c, d, e}); - LoopNest l_before(l); - l_before.prepareForCodegen(); - SimpleIREvaluator cg_before( - LoopNest::sanitizeNames(l_before.root_stmt()), {a, b, e}); - - StmtPtr d_loop = l.getLoopStmtsFor(d)[0]; - l.cacheAccesses(d.buf(), "d_local", d_loop); - l.prepareForCodegen(); - - StmtPtr result = - LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt())); - SimpleIREvaluator cg_after(result, {a, b, e}); - - std::ostringstream oss; - oss << *cg_after.stmt(); - const std::string& expected_ir = - R"IR( -#CHECK: Allocate(d_local); // dtype=float, dims=[4] -#CHECK: for (int i_2 -#CHECK: d_local[i_2] = 0.f -#CHECK: for (int -#CHECK: for (int -#CHECK: d_local[i_2] = (d_local[i_2]) + (scale[ -#CHECK: } -#CHECK: } -#CHECK: } -#CHECK: for (int i_3 -#CHECK: sum[i_3] = d_local[i_3] -#CHECK: Free(d_local); -#CHECK-NOT: d_local - )IR"; - torch::jit::testing::FileCheck().run(expected_ir, oss.str()); - - PaddedBuffer a_v(L, M, N, "a"); - PaddedBuffer b_v(L, M, N, "b"); - PaddedBuffer c_v(L, M, N, "c"); - PaddedBuffer d_v(L, "d"); - PaddedBuffer e_before(L, "e_before"); - PaddedBuffer e_after(L, "e_after"); - - for (const auto l : c10::irange(L)) { - for (const auto m : c10::irange(M)) { - for (const auto n : c10::irange(N)) { - a_v(l, m, n) = at::randn({1}).item().to(); - b_v(l, m, n) = at::randn({1}).item().to(); - } - } - } - - cg_before.call({a_v, b_v, e_before}); - cg_after.call({a_v, b_v, e_after}); - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - ExpectAllNear(e_before, e_after, 1e-5); -} - 
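// What cacheAccesses does in this test and the two variants that follow,
// sketched: given a buffer and a loop level, it allocates a scratch buffer,
// rewrites the accesses inside that loop to hit the scratch, and inserts the
// fill and write-back around it. The chosen level sets the cache footprint,
// which the FileCheck patterns pin down via the Allocate dims:
//
//   l.cacheAccesses(d.buf(), "d_local", l.getLoopStmtsFor(d)[0]);
//       // around the operator axis: dims=[4], plain copy-out
//   l.cacheAccesses(d.buf(), "d_local", l.getLoopStmtsFor(d)[1]);
//       // inside the outer reduce axis: dims=[1], overwriting write-back
//   l.cacheAccesses(d.buf(), "d_local", l.getLoopStmtsFor(d)[2]);
//       // inside the inner reduce axis: dims=[1], accumulating write-back
//       // (sum[i] = sum[i] + d_local[0])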
-TEST(Reductions, ReductionCacheAccessesOuterReduceAxis) { - int L = 4; - int N = 3; - int M = 2; - - BufHandle a("a", {L, N, M}, kFloat); - BufHandle b("b", {L, N, M}, kFloat); - - Tensor c = Compute( - "scale", - {L, N, M}, - [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) { - return b.load(l, n, m) * a.load(l, n, m); - }); - Tensor d = Reduce("sum", {L}, Sum(), c, {N, M}); - - Tensor e = Compute("scale", {L}, [&](const VarHandle& l) { - return b.load(0, 0, l) * d.load(l); - }); - - LoopNest l({e}, {c, d, e}); - LoopNest l_before(l); - l_before.prepareForCodegen(); - SimpleIREvaluator cg_before(l_before.root_stmt(), {a, b, e}); - - StmtPtr d_loop = l.getLoopStmtsFor(d)[1]; - l.cacheAccesses(d.buf(), "d_local", d_loop); - l.prepareForCodegen(); - - StmtPtr result = - LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt())); - SimpleIREvaluator cg_after(result, {a, b, e}); - - std::ostringstream oss; - oss << *cg_after.stmt(); - const std::string& expected_ir = - R"IR( -#CHECK: Allocate(d_local); // dtype=float, dims=[1] -#CHECK: sum[i_1] = 0 -#CHECK: d_local[0] = sum[i_1] -#CHECK: for (int j_1 -#CHECK: for (int k_1 -#CHECK: d_local[0] = (d_local[0]) + (scale[ -#CHECK: } -#CHECK: } -#CHECK: sum[i_1] = d_local[0] -#CHECK: Free(d_local); -#CHECK-NOT: d_local - )IR"; - torch::jit::testing::FileCheck().run(expected_ir, oss.str()); - - PaddedBuffer a_v(L, M, N, "a"); - PaddedBuffer b_v(L, M, N, "b"); - PaddedBuffer c_v(L, M, N, "c"); - PaddedBuffer d_v(L, "d"); - PaddedBuffer e_before(L, "e_before"); - PaddedBuffer e_after(L, "e_after"); - - for (const auto l : c10::irange(L)) { - for (const auto m : c10::irange(M)) { - for (const auto n : c10::irange(N)) { - a_v(l, m, n) = at::randn({1}).item().to(); - b_v(l, m, n) = at::randn({1}).item().to(); - } - } - } - - cg_before.call({a_v, b_v, e_before}); - cg_after.call({a_v, b_v, e_after}); - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - ExpectAllNear(e_before, e_after, 1e-5); -} - -TEST(Reductions, ReductionCacheAccessesInnerReduceAxis) { - int L = 4; - int N = 3; - int M = 2; - - BufHandle a("a", {L, N, M}, kFloat); - BufHandle b("b", {L, N, M}, kFloat); - - Tensor c = Compute( - "scale", - {L, N, M}, - [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) { - return b.load(l, n, m) * a.load(l, n, m); - }); - Tensor d = Reduce("sum", {L}, Sum(), c, {N, M}); - - Tensor e = Compute("scale", {L}, [&](const VarHandle& l) { - return b.load(0, 0, l) * d.load(l); - }); - - LoopNest l({e}, {c, d, e}); - LoopNest l_before(l); - l_before.prepareForCodegen(); - SimpleIREvaluator cg_before(l_before.root_stmt(), {a, b, e}); - - StmtPtr d_loop = l.getLoopStmtsFor(d)[2]; - l.cacheAccesses(d.buf(), "d_local", d_loop); - l.prepareForCodegen(); - - StmtPtr result = - LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt())); - SimpleIREvaluator cg_after(result, {a, b, e}); - - std::ostringstream oss; - oss << *cg_after.stmt(); - const std::string& expected_ir = - R"IR( -#CHECK: Allocate(d_local); // dtype=float, dims=[1] -#CHECK: sum[i_1] = 0 -#CHECK: for (int -#CHECK: d_local[0] = 0 -#CHECK: for (int -#CHECK: d_local[0] = (d_local[0]) + (scale[ -#CHECK: } -#CHECK: sum[i_1] = (sum[i_1]) + (d_local[0]) -#CHECK: } -#CHECK: Free(d_local); -#CHECK-NOT: d_local - )IR"; - torch::jit::testing::FileCheck().run(expected_ir, oss.str()); - - PaddedBuffer a_v(L, M, N, "a"); - PaddedBuffer b_v(L, M, N, "b"); - PaddedBuffer c_v(L, M, N, "c"); - PaddedBuffer d_v(L, "d"); - PaddedBuffer e_before(L, "e_before"); - PaddedBuffer 
e_after(L, "e_after"); - - for (const auto l : c10::irange(L)) { - for (const auto m : c10::irange(M)) { - for (const auto n : c10::irange(N)) { - a_v(l, m, n) = at::randn({1}).item().to(); - b_v(l, m, n) = at::randn({1}).item().to(); - } - } - } - - cg_before.call({a_v, b_v, e_before}); - cg_after.call({a_v, b_v, e_after}); - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) - ExpectAllNear(e_before, e_after, 1e-5); -} - -TEST(Reductions, ReductionCacheBodyAccess) { - BufHandle a("a", {24, 32, 12}, kFloat); - BufHandle b("b", {24, 32, 12}, kFloat); - - Tensor c = Compute( - "scale", - {24, 32, 12}, - [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) { - return b.load(l, n, m) * a.load(l, n, m); - }); - Tensor d = Reduce("sum", {24}, Sum(), c, {32, 12}); - - Tensor e = Compute("scale", {24}, [&](const VarHandle& l) { - return b.load(0, 0, l) * d.load(l); - }); - - LoopNest l({e}, {c, d, e}); - - StmtPtr d_loop = l.getLoopStmtsFor(d)[1]; - l.cacheAccesses(c.buf(), "scale_local", d_loop); - - l.prepareForCodegen(); - StmtPtr result = - LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt())); - SimpleIREvaluator cg(result, {a, b, e}); - - std::ostringstream oss; - oss << *cg.stmt(); - const std::string& expected_ir = - R"IR( -#CHECK: Allocate(scale_local); // dtype=float, dims=[1, 32, 12] -#CHECK: for (int j_1 = 0; j_1 < 32; j_1++) { -#CHECK: for (int k_1 = 0; k_1 < 12; k_1++) { -#CHECK: scale_local[k_1 + 12 * j_1] = scale[(k_1 + 12 * j_1) + 384 * i_1]; -#CHECK: sum[i_1] = (sum[i_1]) + (scale_local[k_2 + 12 * j_2]); -#CHECK: scale_1[i_2] = (b[i_2]) * (sum[i_2]); -#CHECK: Free(scale_local); - )IR"; - torch::jit::testing::FileCheck().run(expected_ir, oss.str()); -} - -TEST(Reductions, ReductionCacheConsumerAccess) { - BufHandle a("a", {24, 32, 12}, kFloat); - BufHandle b("b", {24, 32, 12}, kFloat); - - Tensor c = Compute( - "scale", - {24, 32, 12}, - [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) { - return b.load(l, n, m) * a.load(l, n, m); - }); - Tensor d = Reduce("sum", {24}, Sum(), c, {32, 12}); - - Tensor e = Compute("scale", {24}, [&](const VarHandle& l) { - return b.load(0, 0, l) * d.load(l); - }); - - LoopNest l({e}, {c, d, e}); - - LoopNest::splitWithMask(l.getLoopStmtsFor(e)[0], 4); - - StmtPtr e_loop = l.getLoopStmtsFor(e)[1]; - l.cacheAccesses(d.buf(), "sum_local", e_loop); - l.prepareForCodegen(); - - StmtPtr result = - LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt())); - SimpleIREvaluator cg(result, {a, b, e}); - - std::ostringstream oss; - oss << *cg.stmt(); - const std::string& expected_ir = - R"IR( -#CHECK: Alias(sum_local,scale); -#CHECK: sum[i_1] = (sum[i_1]) + (scale[ -#CHECK: for (int j_2 = 0; j_2 < 4 -#CHECK: sum_local[j_2] = sum[j_2 + 4 * i_2]; -#CHECK: scale_1[j_3 + 4 * i_2] = (b[j_3 + 4 * i_2]) * (sum_local[j_3]); - )IR"; - torch::jit::testing::FileCheck().run(expected_ir, oss.str()); -} - -TEST(Reductions, ReductionSplitCacheConsumerAccess) { - BufHandle a("a", {24, 32, 12}, kFloat); - BufHandle b("b", {24, 32, 12}, kFloat); - - Tensor c = Compute( - "scale", - {24, 32, 12}, - [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) { - return b.load(l, n, m) * a.load(l, n, m); - }); - Tensor d = Reduce("sum", {24}, Sum(), c, {32, 12}); - - Tensor e = Compute("scale", {24}, [&](const VarHandle& l) { - return b.load(0, 0, l) * d.load(l); - }); - - LoopNest l({e}, {c, d, e}); - - ForPtr inner; - - // Split outer reduction axis. 
-  LoopNest::splitWithMask(l.getLoopStmtsFor(d)[0], 4, &inner);
-
-  // Split reduction consumer.
-  LoopNest::splitWithMask(l.getLoopStmtsFor(e)[0], 4, &inner);
-
-  l.cacheAccesses(d.buf(), "sum_local", inner);
-  l.prepareForCodegen();
-
-  StmtPtr result =
-      LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt()));
-  SimpleIREvaluator cg(result, {a, b, e});
-
-  // The reduction changes, but the cache does not.
-  std::ostringstream oss;
-  oss << *cg.stmt();
-  const std::string& expected_ir =
-      R"IR(
-#CHECK: Alias(sum_local,scale);
-#CHECK: sum[j_1 + 4 * i_1] = (sum[j_1 + 4 * i_1]) + (scale[((l + 12 * k_1) + 1536 * i_1) + 384 * j_1]);
-#CHECK: for (int i_2 = 0; i_2 < 6
-#CHECK: for (int j_2 = 0; j_2 < 4
-#CHECK: sum_local[j_2] = sum[j_2 + 4 * i_2];
-#CHECK: for (int j_3 = 0; j_3 < 4
-#CHECK: scale_1[j_3 + 4 * i_2] = (b[j_3 + 4 * i_2]) * (sum_local[j_3]);
-      )IR";
-  torch::jit::testing::FileCheck().run(expected_ir, oss.str());
-}
-
-TEST(Reductions, ReductionReorderCacheConsumerAccess) {
-  BufHandle a("a", {24, 32, 12}, kFloat);
-  BufHandle b("b", {24, 32, 12}, kFloat);
-
-  Tensor c = Compute(
-      "scale",
-      {24, 32, 12},
-      [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) {
-        return b.load(l, n, m) * a.load(l, n, m);
-      });
-  Tensor d = Reduce("sum", {24}, Sum(), c, {32, 12});
-
-  Tensor e = Compute("scale", {24}, [&](const VarHandle& l) {
-    return b.load(0, 0, l) * d.load(l);
-  });
-
-  LoopNest l({e}, {c, d, e});
-
-  ForPtr inner;
-
-  // Reorder outer reduction axes.
-  auto loops = l.getLoopStmtsFor(d);
-  LoopNest::reorderAxis(loops[0], loops[1]);
-
-  // Split reduction consumer.
-  LoopNest::splitWithMask(l.getLoopStmtsFor(e)[0], 4, &inner);
-
-  l.cacheAccesses(d.buf(), "sum_local", inner);
-  l.prepareForCodegen();
-
-  StmtPtr result =
-      LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt()));
-  SimpleIREvaluator cg(result, {a, b, e});
-
-  // Neither the reduction body nor the cache changes.
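// [Editor's sketch; not part of the deleted file.] cacheAccesses(buf, name,
// stmt) redirects every access to `buf` within `stmt` through a small
// temporary buffer, inserting fill/write-back code at that scope; for a
// reduction target this keeps the running sum in a tiny local. Hedged shape:
static void sketchCacheAccesses() {
  BufHandle in("IN", {16, 16}, kFloat);
  Tensor r = Reduce("r", {16}, Sum(), in, {16});
  LoopNest nest({r});
  // Cache the reduction output inside the reduce-axis loop: the partial sum
  // accumulates into r_local[0] and is flushed to r[i] once per row.
  StmtPtr reduce_loop = nest.getLoopStmtsFor(r)[1];
  nest.cacheAccesses(r.buf(), "r_local", reduce_loop);
}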
- std::ostringstream oss; - oss << *cg.stmt(); - const std::string& expected_ir = - R"IR( -#CHECK: sum[j_1] = (sum[j_1]) + (scale[(k_1 + 12 * i_2) + 384 * j_1]); -#CHECK: for (int i_3 = 0; i_3 < 6; -#CHECK: for (int j_2 = 0; j_2 < 4; -#CHECK: sum_local[j_2] = sum[j_2 + 4 * i_3]; -#CHECK: for (int j_3 = 0; j_3 < 4; -#CHECK: scale_1[j_3 + 4 * i_3] = (b[j_3 + 4 * i_3]) * (sum_local[j_3]); - )IR"; - torch::jit::testing::FileCheck().run(expected_ir, oss.str()); -} - -TEST(Reductions, ReductionRfactorCacheTempOuter) { - const int M = 10; - const int N = 10; - const int K = 10; - VarHandle m("m", kInt); - VarHandle n("n", kInt); - VarHandle k("k", kInt); - - BufHandle b("B", {m, n, k}, kFloat); - std::vector in(M * N * K); - for (int j = 0; j < M * N * K; ++j) { - in[j] = j; - } - - std::vector out(1, -1.f); - - Tensor c = Reduce("sum", {}, Sum(), b, {m, n, k}); - LoopNest loop({c}); - - std::vector loops = loop.getLoopStmtsFor(c); - LoopNest::reorderAxis(loops.at(0), loops.at(1)); - loops = loop.getLoopStmtsFor(c); - auto c_body = loop.getAllWritesToBuf(c.buf())[1]; - BufPtr rfac_buf; - ASSERT_TRUE(loop.rfactor(c_body, loops.at(0), &rfac_buf)); - loop.distributeLoop(loops.at(0)); - - auto all_loops = loop.getAllLoopNestsWritingToBuf(rfac_buf); - ASSERT_TRUE(all_loops.size() == 2 && all_loops.at(1).size() == 3); - LoopNest::reorderAxis(all_loops[1][0], all_loops[1][1]); - - all_loops = loop.getAllLoopNestsWritingToBuf(rfac_buf); - LoopNest::cacheAccesses(rfac_buf, "tmp", all_loops[1][1]); - loop.simplify(); - loop.prepareForCodegen(); - StmtPtr s = LoopNest::sanitizeNames(loop.root_stmt()); - SimpleIREvaluator cg(s, {b, c, m, n, k}); - - std::ostringstream oss; - oss << *cg.stmt(); - const std::string& expected_ir = - R"IR( -#CHECK: Allocate(sum_rfac); // dtype=float, dims=[n] -#CHECK: Allocate(tmp); // dtype=float, dims=[n] -#CHECK: for (int i_1 = 0; i_1 < m -#CHECK: for (int j = 0; j < n -#CHECK: tmp[j] = 0 -#CHECK: } -#CHECK: for (int j_1 = 0; j_1 < n -#CHECK: for (int k -#CHECK: tmp[j_1] = (tmp[j_1]) + (B[ -#CHECK: } -#CHECK: } -#CHECK: for (int j_2 = 0; j_2 < n -#CHECK: sum_rfac[j_2] = (sum_rfac[j_2]) + (tmp[j_2]); -#CHECK: } -#CHECK: Free(tmp); -#CHECK-NOT: tmp - )IR"; - torch::jit::testing::FileCheck().run(expected_ir, oss.str()); - - cg.call({in, out, M, N, K}); - ASSERT_EQ(out[0], 499500); -} - -TEST(Reductions, ReductionRfactorCacheTempInner) { - const int M = 10; - const int N = 10; - const int K = 10; - VarHandle m("m", kInt); - VarHandle n("n", kInt); - VarHandle k("k", kInt); - - BufHandle b("B", {m, n, k}, kFloat); - std::vector in(M * N * K); - for (int j = 0; j < M * N * K; ++j) { - in[j] = j; - } - - std::vector out(1, -1.f); - - Tensor c = Reduce("sum", {}, Sum(), b, {m, n, k}); - LoopNest loop({c}); - std::vector loops = loop.getLoopStmtsFor(c); - auto c_body = loop.getAllWritesToBuf(c.buf())[1]; - - LoopNest::reorderAxis(loops.at(0), loops.at(1)); - loops = loop.getLoopStmtsFor(c); - BufPtr rfac_buf; - ASSERT_TRUE(loop.rfactor(c_body, loops.at(0), &rfac_buf)); - loop.distributeLoop(loops.at(0)); - auto all_loops = loop.getAllLoopNestsWritingToBuf(rfac_buf); - ASSERT_TRUE(all_loops.size() == 2 && all_loops.at(1).size() == 3); - LoopNest::reorderAxis(all_loops[1][0], all_loops[1][1]); - - all_loops = loop.getAllLoopNestsWritingToBuf(rfac_buf); - ASSERT_TRUE(all_loops.size() == 2 && all_loops.at(1).size() == 3); - LoopNest::cacheAccesses(rfac_buf, "tmp", all_loops[1][2]); - loop.prepareForCodegen(); - loop.simplify(); - StmtPtr s = LoopNest::sanitizeNames(loop.root_stmt()); - 
SimpleIREvaluator cg(s, {b, c, m, n, k}); - - std::ostringstream oss; - oss << *cg.stmt(); - const std::string& expected_ir = - R"IR( -#CHECK: Allocate(sum_rfac); // dtype=float, dims=[n] -#CHECK: Allocate(tmp); // dtype=float, dims=[1] -#CHECK: for (int i_1 = 0; i_1 < m -#CHECK: for (int j = 0; j < n -#CHECK: tmp[0] = 0 -#CHECK: for (int k -#CHECK: tmp[0] = (tmp[0]) + (B[ -#CHECK: } -#CHECK: sum_rfac[j] = (sum_rfac[j]) + (tmp[0]); -#CHECK: Free(tmp); -#CHECK-NOT: tmp - )IR"; - torch::jit::testing::FileCheck().run(expected_ir, oss.str()); - - cg.call({in, out, M, N, K}); - ASSERT_EQ(out[0], 499500); -} - -TEST(Reductions, ReductionVectorize) { - std::vector in_(8 * 8); - for (const auto i : c10::irange(8)) { - for (const auto j : c10::irange(8)) { - in_[i * 8 + j] = i; - } - } - std::vector out_before(8, -1.f); - std::vector out_after(8, -1.f); - - BufHandle in("in", {8, 8}, kFloat); - - Tensor tensor = Reduce("sum", {8}, Sum(), in, {8}); - LoopNest l_before({tensor}); - LoopNest l(l_before); - l_before.prepareForCodegen(); - SimpleIREvaluator cg_before(l_before.root_stmt(), {in, tensor}); - cg_before.call({in_, out_before}); - - ASSERT_TRUE(LoopNest::vectorize(l.getLoopStmtsFor(tensor)[0])); - - StmtPtr s = l.root_stmt(); - s = LoopNest::sanitizeNames(IRSimplifier::simplify(s)); - - std::ostringstream oss; - oss << *s; - const std::string& expected_ir = - R"IR( -#CHECK: sum[Ramp(0, 1, 8)] = Broadcast(0.f, 8); -#CHECK: for (int i = 0; i < 8; i++) { -#CHECK: sum[Ramp(0, 1, 8)] = ReduceOp((sum[Ramp(0, 1, 8)]) + (in[Ramp(i, 8, 8)]), reduce_args={i}); -#CHECK: } - )IR"; - torch::jit::testing::FileCheck().run(expected_ir, oss.str()); - - // Vectorizing should not change result. - l.prepareForCodegen(); - s = IRSimplifier::simplify(l.root_stmt()); - SimpleIREvaluator cg_after(s, {in, tensor}); - cg_after.call({in_, out_after}); - for (const auto i : c10::irange(8)) { - ASSERT_EQ(out_before[i], out_after[i]); - } -} - -TEST(Reductions, ReductionVectorizeInner) { - BufHandle in("in", {8, 8}, kFloat); - - Tensor tensor = Reduce("sum", {8}, Sum(), in, {8}); - LoopNest l({tensor}); - - ASSERT_FALSE(LoopNest::vectorize(l.getLoopStmtsFor(tensor)[1])); -} - -TEST(Reductions, ReductionVectorizeRfactor) { - std::vector in_(8 * 8); - for (const auto i : c10::irange(8)) { - for (const auto j : c10::irange(8)) { - in_[i * 8 + j] = i; - } - } - std::vector out_before(1, -1.f); - std::vector out_after(1, -1.f); - - BufHandle in("in", {8, 8}, kFloat); - - Tensor tensor = Reduce("sum", {}, Sum(), in, {8, 8}); - - LoopNest l_before({tensor}); - LoopNest l(l_before); - l_before.prepareForCodegen(); - SimpleIREvaluator cg_before(l_before.root_stmt(), {in, tensor}); - cg_before.call({in_, out_before}); - - ASSERT_FALSE(LoopNest::vectorize(l.getLoopStmtsFor(tensor)[1])); - - // But if we rfactor this so it's not a reduce axis we can vectorize that - // loop. 
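// [Editor's sketch; not part of the deleted file.] rfactor materializes
// per-slice partial sums into a new temporary buffer and reduces that buffer
// afterwards, so the partial-sum loop stops being a reduce axis and becomes
// legal to vectorize. A hedged minimal shape (rfactor returns false and
// leaves the nest alone if the chosen axis is not rfactorable):
static void sketchRfactor() {
  BufHandle in("IN", {8, 8}, kFloat);
  Tensor total = Reduce("total", {}, Sum(), in, {8, 8});
  LoopNest nest({total});
  std::vector<ForPtr> loops = nest.getLoopStmtsFor(total);
  auto body = nest.getAllWritesToBuf(total.buf())[1];
  BufPtr partial = nullptr;
  // On success, partial[i] holds row sums and `total` reduces over partial.
  LoopNest::rfactor(body, loops.at(0), &partial);
}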
-  std::vector<ForPtr> loops = l.getLoopStmtsFor(tensor);
-  LoopNest::reorderAxis(loops[0], loops[1]);
-  loops = l.getLoopStmtsFor(tensor);
-  auto tensor_body = l.getAllWritesToBuf(tensor.buf())[1];
-  BufPtr rfac_buf = nullptr;
-  ASSERT_TRUE(LoopNest::rfactor(tensor_body, loops.at(0), &rfac_buf));
-
-  LoopNest::distributeLoop(loops.at(0));
-  auto rfac_loops = l.getAllLoopNestsWritingToBuf(rfac_buf);
-
-  ASSERT_TRUE(LoopNest::vectorize(rfac_loops[1][0]));
-  l.simplify();
-
-  StmtPtr s = LoopNest::sanitizeNames(l.root_stmt());
-
-  std::ostringstream oss;
-  oss << *s;
-  const std::string& expected_ir =
-      R"IR(
-#CHECK: sum = 0.f;
-#CHECK: for (int i = 0; i < 8; i++) {
-#CHECK: sum_rfac[i] = 0.f;
-#CHECK: }
-#CHECK: for (int i_1 = 0; i_1 < 8; i_1++) {
-#CHECK: sum_rfac[Ramp(0, 1, 8)] = ReduceOp((sum_rfac[Ramp(0, 1, 8)]) + (in[Ramp(8 * i_1, 1, 8)]), reduce_args={i_1});
-#CHECK: }
-#CHECK: for (int i_2 = 0; i_2 < 8; i_2++) {
-#CHECK: sum = ReduceOp((sum) + (sum_rfac[i_2]), reduce_args={i_2});
-#CHECK: }
-      )IR";
-  torch::jit::testing::FileCheck().run(expected_ir, oss.str());
-
-  // Vectorizing should not change result.
-  l.prepareForCodegen();
-  s = IRSimplifier::simplify(l.root_stmt());
-  SimpleIREvaluator cg_after(s, {in, tensor});
-  cg_after.call({in_, out_after});
-
-  ASSERT_EQ(out_before[0], out_after[0]);
-}
-
-TEST(Reductions, InitFunction) {
-  constexpr int M = 32;
-  constexpr int N = 16;
-  BufHandle A("A", {M, N}, kFloat);
-  BufHandle B("B", {N}, kFloat);
-  Tensor C = Reduce(
-      "C",
-      {N},
-      Sum(),
-      [&](const std::vector<VarHandle>& v) { return B.load(v[0]); },
-      [&](const std::vector<VarHandle>& v) { return A.load(v[1], v[0]); },
-      {M});
-  LoopNest nest({C});
-  nest.prepareForCodegen();
-  StmtPtr s = LoopNest::sanitizeNames(IRSimplifier::simplify(nest.root_stmt()));
-  std::ostringstream oss;
-  oss << *s << "\n";
-  const std::string& expected_ir =
-      R"IR(
-#CHECK: for (int i = 0; i < 16; i++) {
-#CHECK: C[i] = B[i];
-#CHECK: for (int j = 0; j < 32; j++) {
-#CHECK: C[i] = (C[i]) + (A[i + 16 * j]);
-#CHECK: }
-#CHECK: }
-      )IR";
-  torch::jit::testing::FileCheck().run(expected_ir, oss.str());
-}
-} // namespace jit
-} // namespace torch
diff --git a/test/cpp/tensorexpr/test_registerizer.cpp b/test/cpp/tensorexpr/test_registerizer.cpp
deleted file mode 100644
index 6cbd04264c32..000000000000
--- a/test/cpp/tensorexpr/test_registerizer.cpp
+++ /dev/null
@@ -1,3702 +0,0 @@
-#include <gtest/gtest.h>
-#include "test/cpp/tensorexpr/test_base.h"
-
-#include "test/cpp/tensorexpr/test_utils.h"
-#include "torch/csrc/jit/tensorexpr/ir_simplifier.h"
-#include "torch/csrc/jit/tensorexpr/registerizer.h"
-
-#include <iostream>
-
-namespace torch {
-namespace jit {
-using namespace torch::jit::tensorexpr;
-
-// Can replace a simple scalar access with a local variable.
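// [Editor's sketch; not part of the deleted file.] In plain C++ terms the
// registerizer performs classic scalar replacement: a memory location that is
// repeatedly read and written inside a loop is kept in a local variable and
// written back once at the end. Roughly:
static void sketchScalarReplacement() {
  int A[1] = {0};
  // Before: every iteration loads and stores A[0] through memory.
  for (int x = 0; x < 10; x++) {
    A[0] = A[0] + x;
  }
  // After registerization the same computation becomes:
  int A_1 = 0;
  for (int x = 0; x < 10; x++) {
    A_1 = A_1 + x;
  }
  A[0] = A_1; // single write-back
}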
-TEST(Registerizer, RegisterizerSimple) { - BufHandle a("A", {1}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make( - {Store::make(a, {0}, 0), - For::make( - x, - 0, - 10, - Block::make( - {Store::make(a, {0}, Add::make(Load::make(a, {0}), x))}))}); - - /* - * A[0] = 0; - * for (int x = 0; x < 10; x++) { - * A[0] = (A[0]) + x; - * } - */ - - stmt = registerize(stmt); - - /* - * int A_1 = 0; - * for (int x = 0; x < 10; x++) { - * A_1 = x + A_1; - * } - * A[0] = A_1; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = 0; -# CHECK: for (int x = 0; x < 10; x++) -# CHECK-NOT: A[ -# CHECK: A_1 = -# CHECK: A[0] = A_1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Won't do replacement of a loop access. -TEST(Registerizer, RegisterizerLoop) { - BufHandle a("A", {10}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make( - {Store::make(a, {0}, 0), - For::make( - x, - 0, - 10, - Block::make( - {Store::make(a, {x}, Add::make(Load::make(a, {x}), x))}))}); - - /* - * A[0] = 0; - * for (int x = 0; x < 10; x++) { - * A[x] = (A[x]) + x; - * } - */ - - // No change. - stmt = registerize(stmt); - - /* - * A[0] = 0; - * for (int x = 0; x < 10; x++) { - * A[x] = (A[x]) + x; - * } - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK-NOT: int -# CHECK: A[0] = 0; -# CHECK: for (int x = 0; x < 10; x++) -# CHECK-NOT: A_ -# CHECK: A[x] = -# CHECK-NOT: A[0] = A_1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Won't replace even if the load is a fixed scalar, since the store could -// invalidate it. -TEST(Registerizer, RegisterizerLoopFixedLoad) { - BufHandle a("A", {1}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make( - {Store::make(a, {0}, 0), - For::make( - x, - 0, - 10, - Block::make( - {Store::make(a, {x}, Add::make(Load::make(a, {0}), x))}))}); - - /* - * A[0] = 0; - * for (int x = 0; x < 10; x++) { - * A[x] = (A[0]) + x; - * } - */ - - // No change. - stmt = registerize(stmt); - - /* - * A[0] = 0; - * for (int x = 0; x < 10; x++) { - * A[x] = (A[0]) + x; - * } - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK-NOT: int -# CHECK: A[0] = 0; -# CHECK: for (int x = 0; x < 10; x++) -# CHECK-NOT: A_ -# CHECK: A[x] = -# CHECK-NOT: A[0] = A_1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// We can registerize accesses that occur entirely within inner scopes, even if -// they depend on the loop var. -TEST(Registerizer, RegisterizerLoopInternal) { - BufHandle a("A", {1}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make({For::make( - x, - 0, - 10, - Block::make( - {Store::make(a, {x}, Add::make(Load::make(a, {x}), x)), - Store::make(a, {x}, Add::make(Load::make(a, {x}), x))}))}); - - /* - * for (int x = 0; x < 10; x++) { - * A[x] = (A[x]) + x; - * A[x] = (A[x]) + x; - * } - */ - - stmt = registerize(stmt); - - // TODO: the order of terms in addition changes and in general depends on - // some hash value. This results in unpredictable swaps of the operands from - // random changes, which is not great. Ideally, we should ensure some - // specific order (ideally, the original one). 
- /* - * for (int x = 0; x < 10; x++) { - * int A_1 = A[x]; - * A_1 = x + A_1; - * A_1 = x + A_1; - * A[x] = A_1; - * } - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: for (int x = 0; x < 10; x++) -# CHECK: int A_1 = A[x]; -# CHECK: A_1 = A_1 + x; -# CHECK: A_1 = A_1 + x; -# CHECK: A[x] = A_1; -# CHECK: })IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// An access can be overlapped by another read in the same Expr. In this case -// B[z] and B[y] overlap and prevent registerization of both accesses. -TEST(Registerizer, RegisterizerLoopInternalLoadOverlap) { - BufHandle a("A", {10}, kInt); - BufHandle b("B", {10}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - VarHandle z("z", kInt); - StmtPtr stmt = Block::make({For::make( - x, - 0, - 10, - Store::make(a, {x}, Add::make(Load::make(b, {y}), Load::make(b, {z}))))}); - stmt = IRSimplifier::simplify(stmt); - - /* - * for (int x = 0; x < 10; x++) { - * A[x] = (B[y]) + (B[z]); - * } - */ - - std::ostringstream before; - before << *stmt; - - // No change. - stmt = registerize(stmt); - - std::ostringstream after; - after << *stmt; - - ASSERT_EQ(before.str(), after.str()); -} - -TEST(Registerizer, RegisterizerLoopInternalRepeated) { - BufHandle a("A", {1}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make( - {For::make( - x, - 0, - 10, - Block::make( - {Store::make(a, {0}, Add::make(Load::make(a, {1}), x)), - Store::make(a, {0}, Add::make(Load::make(a, {1}), x))})), - For::make( - x, - 0, - 10, - Block::make( - {Store::make(a, {0}, Add::make(Load::make(a, {1}), x)), - Store::make(a, {0}, Add::make(Load::make(a, {1}), x))})) - - }); - - /* - * for (int x = 0; x < 10; x++) { - * A[0] = x + (A[1]); - * A[0] = x + (A[1]); - * } - * for (int x = 0; x < 10; x++) { - * A[0] = x + (A[1]); - * A[0] = x + (A[1]); - * } - */ - - stmt = registerize(stmt); - - /* - * int A_1 = A[1]; - * int A_2 = A[0]; - * for (int x = 0; x < 10; x++) { - * A_2 = A_1 + x; - * A_2 = A_1 + x; - * } - * for (int x = 0; x < 10; x++) { - * A_2 = A_1 + x; - * A_2 = A_1 + x; - * } - * A[0] = A_2; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = A[1]; -# CHECK: int A_2 = A[0]; -# CHECK: for (int x = 0; x < 10; x++) -# CHECK: A_2 = A_1 + x; -# CHECK: A_2 = A_1 + x; -# CHECK: } -# CHECK: for (int x = 0; x < 10; x++) -# CHECK: A_2 = A_1 + x; -# CHECK: A_2 = A_1 + x; -# CHECK: } -# CHECK-NOT: A[1] -# CHECK: A[0] = A_2; -# CHECK-NOT: A[1] -# CHECK: })IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(Registerizer, RegisterizerLoopInternalRepeatedOverlapLoopVar) { - BufHandle a("A", {1}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make( - {For::make( - x, - 0, - 10, - Block::make( - {Store::make(a, {0}, Add::make(Load::make(a, {x}), x)), - Store::make(a, {0}, Add::make(Load::make(a, {x}), x))})), - For::make( - x, - 0, - 10, - Block::make( - {Store::make(a, {0}, Add::make(Load::make(a, {x}), x)), - Store::make(a, {0}, Add::make(Load::make(a, {x}), x))})) - - }); - stmt = IRSimplifier::simplify(stmt); - - /* - * for (int x = 0; x < 10; x++) { - * A[0] = (A[x]) + x; - * A[0] = (A[x]) + x; - * } - * for (int x = 0; x < 10; x++) { - * A[0] = (A[x]) + x; - * A[0] = (A[x]) + x; - * } - */ - - std::ostringstream before; - before << *stmt; - - // No change. 
- stmt = registerize(stmt); - - std::ostringstream after; - after << *stmt; - - ASSERT_EQ(before.str(), after.str()); -} - -TEST(Registerizer, RegisterizerLoopInternalRepeatedOverlapOther) { - BufHandle a("A", {1}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - StmtPtr stmt = IRSimplifier::simplify(Block::make( - {For::make( - x, - 0, - 10, - Block::make( - {Store::make(a, {0}, Add::make(x, Load::make(a, {y}))), - Store::make(a, {0}, Add::make(x, Load::make(a, {y})))})), - For::make( - x, - 0, - 10, - Block::make( - {Store::make(a, {0}, Add::make(x, Load::make(a, {y}))), - Store::make(a, {0}, Add::make(x, Load::make(a, {y})))})) - - })); - - /* - * for (int x = 0; x < 10; x++) { - * A[0] = (A[x]) + x; - * A[0] = (A[x]) + x; - * } - * for (int x = 0; x < 10; x++) { - * A[0] = (A[x]) + x; - * A[0] = (A[x]) + x; - * } - */ - - std::ostringstream before; - before << *stmt; - - // No change. - stmt = registerize(stmt); - - std::ostringstream after; - after << *stmt; - - ASSERT_EQ(before.str(), after.str()); -} - -// Will registerize multiple accesses of different items of the same buffer. -TEST(Registerizer, RegisterizerMultiVar) { - BufHandle a("A", {2}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make({ - Store::make(a, {0}, 0), - Store::make(a, {1}, 0), - For::make( - x, - 0, - 10, - Block::make( - {Store::make(a, {0}, Add::make(Load::make(a, {0}), x)), - Store::make(a, {1}, Sub::make(Load::make(a, {1}), x))})), - }); - - /* - * A[0] = 0; - * A[1] = 0; - * for (int x = 0; x < 10; x++) { - * A[0] = (A[0]) + x; - * A[1] = (A[1]) - x; - * } - */ - - stmt = registerize(stmt); - - /* - * int A_1 = 0; - * int A_2 = 0; - * for (int x = 0; x < 10; x++) { - * A_2 = x + A_2; - * A_1 = A_1 - x; - * } - * A[1] = A_2; - * A[0] = A_1; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = 0; -# CHECK: int A_2 = 0; -# CHECK: for (int x = 0; x < 10; x++) -# CHECK-NOT: A[ -# CHECK: A_1 = -# CHECK: A_2 = -# CHECK: A[1] = A_2 -# CHECK: A[0] = A_1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Will registerize the valid accesses while skipping invalid replacements. -TEST(Registerizer, RegisterizerVariableLoad) { - BufHandle a("A", {1}, kInt); - BufHandle b("B", {10}, kInt); - VarHandle x("x", kInt); - VarHandle x2("x", kInt); - StmtPtr stmt = Block::make( - {Store::make(a, {0}, 0), - For::make(x, 0, 10, Store::make(b, {x}, x)), - For::make( - x2, - 0, - 10, - Block::make({Store::make( - a, {0}, Add::make(Load::make(a, {0}), Load::make(b, {x2})))}))}); - - /* - * A[0] = 0; - * for (int x = 0; x < 10; x++) { - * B[x] = x; - * } - * for (int x_1 = 0; x_1 < 10; x_1++) { - * A[0] = (A[0]) + (B[x_1]); - * } - */ - - stmt = registerize(stmt); - - /* - * int A_1 = 0; - * for (int x = 0; x < 10; x++) { - * B[x] = x; - * } - * for (int x_1 = 0; x_1 < 10; x_1++) { - * A_1 = A_1 + (B[x_1]); - * } - * A[0] = A_1; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = 0; -# CHECK: for (int x = 0; x < 10; x++) -# CHECK: B[x] = x -# CHECK: for (int x_1 = 0; x_1 < 10; x_1++) -# CHECK-NOT: A[ -# CHECK: A_1 = -# CHECK: A[0] = A_1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Can registerize variable accesses so long as the variable does not change. 
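// [Editor's sketch; not part of the deleted file.] The registerized index
// need not be a constant, only invariant over the replaced region. The plain
// C++ analogue for a symbolic index i that the loop never modifies:
static void sketchSymbolicIndex(int i) {
  int A[8] = {0};
  int A_1 = 0; // A[i] lifted into a scalar; i is fixed for the whole loop
  for (int x = 0; x < 10; x++) {
    A_1 = A_1 + x;
  }
  A[i] = A_1; // one store at the still-unchanged index
}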
-TEST(Registerizer, RegisterizerSymbolicIndices) { - VarHandle i("i", kInt); - VarHandle N("N", kInt); - BufHandle a("A", {N}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make( - {Store::make(a, {i}, 0), - For::make( - x, - 0, - 10, - Block::make( - {Store::make(a, {i}, Add::make(Load::make(a, {i}), x))}))}); - - /* - * A[i] = 0; - * for (int x = 0; x < 10; x++) { - * A[i] = (A[i]) + x; - * } - */ - - stmt = registerize(stmt); - - /* - * int A_1 = 0; - * for (int x = 0; x < 10; x++) { - * A_1 = x + A_1; - * } - * A[i] = A_1; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = 0; -# CHECK: for (int x = 0; x < 10; x++) -# CHECK-NOT: A[ -# CHECK: A_1 = -# CHECK: A[i] = A_1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Can registerize accesses dependent on multiple loop vars. -TEST(Registerizer, RegisterizerMultiLoop) { - BufHandle a("A", {1}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - StmtPtr stmt = Block::make( - {Store::make(a, {0}, 0), - For::make( - x, - 0, - 10, - For::make( - y, - 0, - 10, - Block::make({Store::make( - a, - {0}, - Mul::make(Add::make(Load::make(a, {0}), x), y))})))}); - - /* - * A[0] = 0; - * for (int x = 0; x < 10; x++) { - * for (int y = 0; y < 10; y++) { - * A[0] = x * y + (A[0]) * y; - * } - * } - */ - - stmt = registerize(stmt); - - /* - * int A_1 = 0; - * for (int x = 0; x < 10; x++) { - * for (int y = 0; y < 10; y++) { - * A_1 = x * y + y * A_1; - * } - * } - * A[0] = A_1; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = 0; -# CHECK: for (int x = 0; x < 10; x++) -# CHECK: for (int y = 0; y < 10; y++) -# CHECK-NOT: A[ -# CHECK: A_1 = -# CHECK: A[0] = A_1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Can registerize correctly if scalars already exist in the program. -TEST(Registerizer, RegisterizerRepeated) { - BufHandle a("A", {2}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make({ - Store::make(a, {0}, 0), - Store::make(a, {1}, 0), - For::make( - x, - 0, - 10, - Block::make( - {Store::make(a, {0}, Add::make(Load::make(a, {0}), x)), - Store::make(a, {1}, Sub::make(Load::make(a, {1}), x))})), - }); - - // Registerize manually to make sure we only replace a single target. - { - registerizer::RegisterizerAnalysis analysis; - stmt->accept(&analysis); - auto candidates = analysis.getCandidates(); - ASSERT_EQ(candidates.size(), 2); - - candidates.pop_back(); - registerizer::RegisterizerReplacer replacer(candidates); - stmt = stmt->accept_mutator(&replacer); - } - - // Re-analyze and replace the second target. - { - registerizer::RegisterizerAnalysis analysis; - stmt->accept(&analysis); - auto candidates = analysis.getCandidates(); - ASSERT_EQ(candidates.size(), 1); - - registerizer::RegisterizerReplacer replacer(candidates); - stmt = stmt->accept_mutator(&replacer); - } - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = 0; -# CHECK: int A_1_1 = 0; -# CHECK: for (int x = 0; x < 10; x++) -# CHECK-NOT: A[ -# CHECK: A_1 = -# CHECK: A_1_1 = -# CHECK: A[1] = A_1_1; -# CHECK: A[0] = A_1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Can registerize the load of A. 
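// [Editor's note] Despite the comment above, the next test contains no loads
// of A at all: repeated stores to A[0] are enough on their own, since only
// the final value has to reach memory. Plain C++ analogue:
static void sketchStoreOnlyRegisterization() {
  int A[1];
  int A_1 = 0; // the initializing store lives in a scalar
  for (int x = 0; x < 10; x++) {
    A_1 = x + 1;
  }
  A[0] = A_1; // one store instead of eleven
}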
-TEST(Registerizer, RegisterizerNoLoads) {
-  BufHandle a("A", {1}, kInt);
-  VarHandle x("x", kInt);
-  StmtPtr stmt = Block::make(
-      {Store::make(a, {0}, 0),
-       For::make(
-           x, 0, 10, Block::make({Store::make(a, {0}, Add::make(x, 1))}))});
-
-  /*
-   * A[0] = 0;
-   * for (int x = 0; x < 10; x++) {
-   *   A[0] = x + 1;
-   * }
-   */
-
-  stmt = registerize(stmt);
-
-  /*
-   * int A_1 = 0;
-   * for (int x = 0; x < 10; x++) {
-   *   A_1 = x + 1;
-   * }
-   * A[0] = A_1;
-   */
-
-  std::ostringstream oss;
-  oss << *stmt;
-
-  const std::string& verification_pattern =
-      R"IR(
-# CHECK: int A_1 = 0;
-# CHECK: for (int x = 0; x < 10; x++)
-# CHECK-NOT: A[
-# CHECK: A_1 =
-# CHECK: A[0] = A_1;)IR";
-
-  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
-}
-
-// Can registerize the load of A but not the store of B.
-TEST(Registerizer, RegisterizerNoRepeatedStores) {
-  BufHandle a("A", {1}, kInt);
-  BufHandle b("B", {10}, kInt);
-  VarHandle x("x", kInt);
-  StmtPtr stmt = Block::make(
-      {Store::make(a, {0}, 0),
-       For::make(
-           x,
-           0,
-           10,
-           Block::make(
-               {Store::make(b, {x}, Add::make(Load::make(a, {0}), x))}))});
-
-  /*
-   * A[0] = 0;
-   * for (int x = 0; x < 10; x++) {
-   *   B[x] = (A[0]) + x;
-   * }
-   */
-
-  stmt = registerize(stmt);
-
-  // TODO: it's unnecessary to reorder the initializer of A[0], but it's not
-  // actually worse, so let's not worry for now.
-
-  /*
-   * int A_1 = 0;
-   * for (int x = 0; x < 10; x++) {
-   *   B[x] = x + A_1;
-   * }
-   * A[0] = A_1;
-   */
-
-  std::ostringstream oss;
-  oss << *stmt;
-
-  const std::string& verification_pattern =
-      R"IR(
-# CHECK: int A_1 = 0;
-# CHECK: for (int x = 0; x < 10; x++)
-# CHECK-NOT: A_
-# CHECK: B[x] =
-# CHECK: A[0] = A_1;)IR";
-
-  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
-}
-
-// Won't registerize if there are multiple accesses which may overlap.
-TEST(Registerizer, RegisterizerMultiVarOverlap) {
-  BufHandle a("A", {2}, kInt);
-  VarHandle x("x", kInt);
-  StmtPtr stmt = Block::make({
-      Store::make(a, {0}, 0),
-      Store::make(a, {1}, 0),
-      For::make(
-          x,
-          0,
-          10,
-          Block::make(
-              {Store::make(a, {x}, Add::make(Load::make(a, {0}), x)),
-               Store::make(a, {x + 1}, Sub::make(Load::make(a, {1}), x))})),
-  });
-  stmt = IRSimplifier::simplify(stmt);
-
-  std::ostringstream before;
-  before << *stmt;
-
-  // No change.
- stmt = registerize(stmt); - - std::ostringstream after; - after << *stmt; - - ASSERT_EQ(before.str(), after.str()); -} - -TEST(Registerizer, RegisterizerAllocs) { - BufHandle a("A", {2}, kInt); - BufHandle c("C", {1}, kInt); - VarHandle x("x", kInt); - - BufHandle b("B", {Load::make(c, {0})}, kInt); - - StmtPtr stmt = Block::make( - {Allocate::make(b), - Store::make(a, {0}, Load::make(c, {0})), - Store::make(b, {0}, 0), - For::make( - x, - 0, - 10, - Block::make( - {Store::make(b, {0}, Add::make(Load::make(b, {0}), x)), - Store::make(a, {0}, Load::make(c, {0}))})), - Free::make(b)}); - - /* - * Allocate(B, int, {C[0]}); - * A[0] = C[0]; - * B[0] = 0; - * for (int x = 0; x < 10; x++) { - * B[0] = (B[0]) + x; - * A[0] = C[0]; - * } - * Free(B); - */ - - stmt = registerize(stmt); - - /* - * int C_1 = C[0]; - * Allocate(B, int, {C_}); - * int A_1 = C_1; - * int B_1 = 0; - * for (int x = 0; x < 10; x++) { - * B_1 = B_1 + x; - * A_1 = C_1; - * } - * B[0] = B_1; - * A[0] = A_1; - * Free(B); - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int C_1 = C[0]; -# CHECK: Allocate(B -# CHECK: int A_1 = C_1; -# CHECK: int B_1 = 0; -# CHECK: for (int x = 0; x < 10; x++) -# CHECK: B_1 = -# CHECK: A_1 = C_ -# CHECK: B[0] = B_1; -# CHECK: A[0] = A_1; -# CHECK: Free(B)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(Registerizer, RegisterizerNoInitializer) { - BufHandle a("A", {1}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make({For::make( - x, - 0, - 10, - Block::make({Store::make(a, {0}, Add::make(Load::make(a, {0}), x))}))}); - - /* - * for (int x = 0; x < 10; x++) { - * A[0] = (A[0]) + x; - * } - */ - - stmt = registerize(stmt); - - /* - * int A_1 = A[0]; - * for (int x = 0; x < 10; x++) { - * A_1 = x + A_1; - * } - * A[0] = A_1; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = A[0]; -# CHECK: for (int x = 0; x < 10; x++) -# CHECK-NOT: A[ -# CHECK: A_1 = -# CHECK: A[0] = A_1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(Registerizer, RegisterizerNoInitializerLoopVar) { - BufHandle a("A", {1}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make({For::make( - x, - 0, - 10, - Block::make({Store::make(a, {x}, Add::make(Load::make(a, {x}), x))}))}); - stmt = IRSimplifier::simplify(stmt); - - /* - * for (int x = 0; x < 10; x++) { - * A[x] = (A[x]) + x; - * } - */ - - std::ostringstream before; - before << *stmt; - - // No change. 
- stmt = registerize(stmt); - - std::ostringstream after; - after << *stmt; - - ASSERT_EQ(before.str(), after.str()); -} - -TEST(Registerizer, RegisterizerLoadThenStore) { - BufHandle a("A", {1}, kInt); - BufHandle b("B", {1}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make({For::make( - x, - 0, - 10, - Block::make( - {Store::make(b, {0}, Add::make(Load::make(a, {0}), x)), - Store::make(a, {0}, Load::make(b, {0}))}))}); - - /* - * for (int x = 0; x < 10; x++) { - * B[0] = (A[0]) + x; - * A[0] = B[0]; - * } - */ - - stmt = registerize(stmt); - - /* - * int A_1 = A[0]; - * int B_1 = B[0]; - * for (int x = 0; x < 10; x++) { - * B_1 = x + A_1; - * A_1 = B_1; - * } - * B[0] = B_1; - * A[0] = A_1; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = A[0]; -# CHECK: int B_1 = B[0]; -# CHECK: for (int x = 0; x < 10; x++) -# CHECK-NOT: B[ -# CHECK: B_1 = -# CHECK-NOT: A[ -# CHECK: A_1 = B_ -# CHECK: B[0] = B_ -# CHECK: A[0] = A_1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(Registerizer, RegisterizerParallelized) { - BufHandle a("A", {1}, kInt); - VarHandle x("x", kInt); - LoopOptions loopOpts; - loopOpts.set_gpu_block_index(0); - StmtPtr stmt = Block::make( - {Store::make(a, {0}, 0), - For::make( - x, - 0, - 10, - Block::make({Store::make(a, {0}, Add::make(Load::make(a, {0}), x))}), - loopOpts)}); - - /* - * A[0] = 0; - * for (int x = 0; x < 10; x++) { - * A[0] = (A[0]) + x; - * } - */ - - ASSERT_THROWS_WITH( - registerize(stmt), - "Registerization must occur after parallelism flattening"); -} - -// Should be able to registerize this since the scalar would exist before the -// branch. -TEST(Registerizer, RegisterizerConditionAfter) { - BufHandle a("A", {5}, kInt); - BufHandle b("B", {5}, kInt); - BufHandle c("C", {5}, kInt); - VarHandle x("x", kInt); - - StmtPtr stmt = Block::make( - {Store::make(a, {x}, Load::make(b, {x})), - Store::make(c, {x}, Load::make(a, {x})), - Cond::make( - CompareSelect::make(x, 5, CompareSelectOperation::kLT), - Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), - nullptr)}); - - /* - * A[x] = B[x]; - * C[x] = A[x]; - * if (x<5 ? 1 : 0) { - * A[x] = (A[x]) + 1; - * } - */ - - stmt = registerize(stmt); - - /* - * int A_1 = B[x]; - * C[x] = A_1; - * if (x<5 ? 1 : 0) { - * A_1 = A_1 + 1; - * } - * A[x] = A_1; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = B[x]; -# CHECK: C[x] = A_1; -# CHECK: if ( -# CHECK: A_1 = A_1 + 1; -# CHECK: A[x] = A_1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Should be able to registerize this since the scalar exists in the same form -// after the branch and there is no overlap. -TEST(Registerizer, RegisterizerConditionBefore) { - BufHandle a("A", {5}, kInt); - BufHandle b("B", {5}, kInt); - BufHandle c("C", {5}, kInt); - VarHandle x("x", kInt); - - StmtPtr stmt = Block::make( - {Cond::make( - CompareSelect::make(x, 5, CompareSelectOperation::kLT), - Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), - nullptr), - Store::make(a, {x}, Load::make(b, {x})), - Store::make(c, {x}, Load::make(a, {x}))}); - - /* - * if (x<5 ? 1 : 0) { - * A[x] = (A[x]) + 1; - * } - * A[x] = B[x]; - * C[x] = A[x]; - */ - - stmt = registerize(stmt); - - /* - * int A_ 1 = A[x]; - * if (x<5 ? 
1 : 0) { - * A_1 = A_1 + 1; - * } - * A_1 = B[x]; - * C[x] = A_1; - * A[x] = A_1; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = A[x]; -# CHECK: if ( -# CHECK: A_1 = A_1 + 1; -# CHECK: } -# CHECK: A_1 = B[x]; -# CHECK: C[x] = A_1; -# CHECK: A[x] = A_1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Should be able to registerize this as the combination of the two above rules. -TEST(Registerizer, RegisterizerConditionInside) { - BufHandle a("A", {5}, kInt); - BufHandle b("B", {5}, kInt); - BufHandle c("C", {5}, kInt); - VarHandle x("x", kInt); - - StmtPtr stmt = Block::make( - {Store::make(a, {x}, Load::make(b, {x})), - Store::make(c, {x}, Load::make(a, {x})), - Cond::make( - CompareSelect::make(x, 5, CompareSelectOperation::kLT), - Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), - nullptr), - Store::make(b, {x}, Load::make(a, {x})), - Store::make(a, {x}, Load::make(c, {x}))}); - - /* - * A[x] = B[x]; - * C[x] = A[x]; - * if (x<5 ? 1 : 0) { - * A[x] = (A[x]) + 1; - * } - * B[x] = A[x]; - * A[x] = C[x]; - */ - - stmt = registerize(stmt); - - /* - * int A_1 = B[x]; - * C[x] = A_1; - * if (x<5 ? 1 : 0) { - * A_1 = A_1 + 1; - * } - * B[x] = A_1; - * A_1 = C[x]; - * A[x] = A_1; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = B[x]; -# CHECK: C[x] = A_1; -# CHECK: if ( -# CHECK: A_1 = A_1 + 1; -# CHECK: } -# CHECK: B[x] = A_1; -# CHECK: A_1 = C[x]; -# CHECK: A[x] = A_1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// An example where an access is cut by an overlapping access inside a -// condition, and both sides are large enough to be registerized but cannot be -// because there is no safe place to put the initializer or finalizer. -TEST(Registerizer, RegisterizerConditionInsideOverlap1) { - BufHandle a("A", {5}, kInt); - BufHandle b("B", {5}, kInt); - BufHandle c("C", {5}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - - StmtPtr stmt = Block::make( - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - {Store::make(a, {x}, Load::make(b, {x})), - Store::make(c, {x}, Load::make(a, {x})), - Cond::make( - CompareSelect::make(x, 5, CompareSelectOperation::kLT), - Block::make({ - Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), - Store::make(a, {0}, 3), - Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), - }), - nullptr), - Store::make(b, {x}, Load::make(a, {x})), - Store::make(a, {x}, Load::make(c, {x}))}); - - /* - * A[x] = B[x]; - * C[x] = A[x]; - * if (x<5 ? 1 : 0) { - * A[x] = (A[x]) + 1; - * A[0] = 3; - * A[x] = (A[x]) + 1; - * } - * B[x] = A[x]; - * A[x] = C[x]; - */ - - // The A[0] store overlaps, A[x] cutting the region that can be registerized - // into two groups. - // Each group has 2 loads and 2 stores however, so we could registerize it, - // but the first group would need to be finalized inside the condition block, - // the second would need to be initialized inside the condition block. There's - // no safe place to put these that's visible to the other uses in the group - // and so neither registerization is possible. - - std::ostringstream before; - before << *stmt; - - // No change. 
- stmt = registerize(stmt); - - std::ostringstream after; - after << *stmt; - - ASSERT_EQ(before.str(), after.str()); -} - -// Same as the above, but the access group before the condition (and after the -// condition) are large enough to be registerized without needing the access -// from the loop. Registerization occurs but does not include any accesses in -// the condition, and the first group must be finalized before the Cond, the -// second initialized after it. -TEST(Registerizer, RegisterizerConditionInsideOverlap2) { - BufHandle a("A", {5}, kInt); - BufHandle b("B", {5}, kInt); - BufHandle c("C", {5}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - - StmtPtr stmt = Block::make( - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - {Store::make(a, {x}, Load::make(b, {x})), - Store::make(a, {x}, Load::make(b, {x + 1})), - Store::make(c, {x}, Load::make(a, {x})), - Cond::make( - CompareSelect::make(x, 5, CompareSelectOperation::kLT), - Block::make({ - Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), - Store::make(a, {0}, 3), - Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), - }), - nullptr), - Store::make(b, {x}, Load::make(a, {x})), - Store::make(b, {x + 1}, Load::make(a, {x})), - Store::make(a, {x}, Load::make(c, {x}))}); - - /* - * A[x] = B[x]; - * A[x] = B[x + 1]; - * C[x] = A[x]; - * if (x<5 ? 1 : 0) { - * A[x] = (A[x]) + 1; - * A[0] = 3; - * A[x] = (A[x]) + 1; - * } - * B[x] = A[x]; - * B[x + 1] = A[x]; - * A[x] = C[x]; - */ - - stmt = registerize(stmt); - - /* - * int A_1 = B[x]; // A_1 initializer - * A_1 = B[x + 1]; // - * C[x] = A_1; // - * A[x] = A_1; // A_1 finalizer - * if (x<5 ? 1 : 0) { - * A[x] = (A[x]) + 1; - * A[0] = 3; - * A[x] = (A[x]) + 1; - * } - * int A_2 = A[x]; // A_2 initializer - * B[x] = A_2; // - * B[x + 1] = A_2; // - * A_2 = C[x]; // - * A[x] = A_2; // A_2 finalizer - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = B[x]; -# CHECK: A_1 = B[x + 1]; -# CHECK: C[x] = A_1; -# CHECK: A[x] = A_1; -# CHECK: if ( -# CHECK-NOT: A_1 = A_1 + 1; -# CHECK: A[x] = (A[x] -# CHECK: A[0] = -# CHECK: A[x] = (A[x] -# CHECK: } -# CHECK: int A_2 = A[x]; -# CHECK: B[x] = A_2; -# CHECK: B[x + 1] = A_2; -# CHECK: A_2 = C[x]; -# CHECK: A[x] = A_2;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// When accesses are within conditional blocks they are not visible to the wider -// program, because we don't know if the branch would be taken and if it isn't -// the accesses in it don't need to be valid (think size checks on the index). -// In this case the accesses cannot be registerized. -TEST(Registerizer, RegisterizerConditionHidden) { - BufHandle a("A", {5}, kInt); - BufHandle b("B", {5}, kInt); - BufHandle c("C", {5}, kInt); - VarHandle x("x", kInt); - - StmtPtr stmt = Block::make( - {Cond::make( - CompareSelect::make(x, 5, CompareSelectOperation::kLT), - Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), - nullptr), - Cond::make( - CompareSelect::make(x, 5, CompareSelectOperation::kGT), - Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), - nullptr)}); - - /* - * if (x<5 ? 1 : 0) { - * A[x] = (A[x]) + 1; - * } - * if (x>5 ? 1 : 0) { - * A[x] = (A[x]) + 1; - * } - */ - - std::ostringstream before; - before << *stmt; - - // No change. - stmt = registerize(stmt); - - std::ostringstream after; - after << *stmt; - - ASSERT_EQ(before.str(), after.str()); -} - -// But... 
if the same access is found in a non conditional scope, that means -// that that access is valid in the higher scope (or at least if its not it's -// the user's fault). It "unhides" the conditional accesses, allowing -// registerization to occur. -TEST(Registerizer, RegisterizerConditionUnhidden) { - BufHandle a("A", {5}, kInt); - BufHandle b("B", {5}, kInt); - BufHandle c("C", {5}, kInt); - VarHandle x("x", kInt); - - StmtPtr stmt = Block::make( - {Cond::make( - CompareSelect::make(x, 5, CompareSelectOperation::kLT), - Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), - nullptr), - Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), - Cond::make( - CompareSelect::make(x, 5, CompareSelectOperation::kGT), - Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), - nullptr)}); - - /* - * if (x<5 ? 1 : 0) { - * A[x] = (A[x]) + 1; - * } - * A[x] = (A[x]) + 1; <-- this is doing the unhiding. - * if (x>5 ? 1 : 0) { - * A[x] = (A[x]) + 1; - * } - */ - - stmt = registerize(stmt); - - /* - * int A_1 = A[x]; - * if (x<5 ? 1 : 0) { - * A_1 = A_1 + 1; - * } - * A_1 = A_1 + 1; - * if (x>5 ? 1 : 0) { - * A_1 = A_1 + 1; - * } - * A[x] = A_1; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = A[x]; -# CHECK: if (x<5 -# CHECK: A_1 = A_1 + 1; -# CHECK: } -# CHECK: A_1 = A_1 + 1; -# CHECK: if (x>5 -# CHECK: A_1 = A_1 + 1; -# CHECK: } -# CHECK: A[x] = A_1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Can registerize a load that occurs in the condition of a Cond. -TEST(Registerizer, RegisterizerCondCondition) { - BufHandle a("A", {5}, kInt); - BufHandle b("B", {5}, kInt); - BufHandle c("C", {5}, kInt); - VarHandle x("x", kInt); - - StmtPtr stmt = Block::make( - {Store::make(a, {x}, Load::make(b, {x})), - Store::make(c, {x}, Load::make(a, {x})), - Cond::make( - CompareSelect::make( - Load::make(a, {x}), 5, CompareSelectOperation::kLT), - Store::make(c, {x}, Add::make(Load::make(c, {x}), 1)), - nullptr)}); - - /* - * A[x] = B[x]; - * C[x] = A[x]; - * if ((A[x])<5 ? 1 : 0) { - * C[x] = (C[x]) + 1; - * } - */ - - stmt = registerize(stmt); - - /* - * int A_1 = B[x]; - * int C_1 = A_1; - * if (A_1<5 ? 1 : 0) { - * C_1 = C_1 + 1; - * } - * C[x] = C_1; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = B[x]; -# CHECK: int C_1 = A_1; -# CHECK: if (A_1<5 -# CHECK: C_1 = C_1 + 1; -# CHECK: C[x] = C_1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Appearing in the condition of a Cond makes it visible to the enclosing scope, -// and so we can registerize internal usages. -TEST(Registerizer, RegisterizerCondConditionUnhidden) { - BufHandle a("A", {5}, kInt); - BufHandle b("B", {5}, kInt); - BufHandle c("C", {5}, kInt); - VarHandle x("x", kInt); - - StmtPtr stmt = Block::make({Cond::make( - CompareSelect::make(Load::make(a, {x}), 5, CompareSelectOperation::kLT), - Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), - Store::make(a, {x}, Add::make(Load::make(a, {x}), 10)))}); - - /* - * if ((A[x])<5 ? 1 : 0) { - * A[x] = (A[x]) + 1; - * } else { - * A[x] = (A[x]) + 10; - * } - */ - - stmt = registerize(stmt); - - /* - * int A_1 = A[x]; - * if (A_1<5 ? 
1 : 0) { - * A_1 = A_1 + 1; - * } else { - * A_1 = A_1 + 10; - * } - * A[x] = A_1; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = A[x]; -# CHECK: if (A_1<5 -# CHECK: A_1 = A_1 + 1; -# CHECK: } else { -# CHECK: A_1 = A_1 + 10; -# CHECK: } -# CHECK: A[x] = A_1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Conditional hiding also works for IfThenElse exprs. -TEST(Registerizer, RegisterizerIfThenElseHidden) { - BufHandle a("A", {5}, kInt); - BufHandle b("B", {5}, kInt); - BufHandle c("C", {5}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - - StmtPtr stmt = Block::make( - {Store::make( - b, - {y}, - IfThenElse::make( - CompareSelect::make(x, 5, CompareSelectOperation::kLT), - Add::make(Load::make(a, {x}), 1), - Add::make(Load::make(a, {x + 1}), 2))), - Store::make( - b, - {y + 1}, - IfThenElse::make( - CompareSelect::make(x, 5, CompareSelectOperation::kLT), - Add::make(Load::make(a, {x}), 1), - Add::make(Load::make(a, {x + 1}), 2)))}); - - /* - * B[y] = IfThenElse(x<5 ? 1 : 0, (A[x]) + 1, (A[x + 1]) + 2); - * B[y + 1] = IfThenElse(x<5 ? 1 : 0, (A[x]) + 1, (A[x + 1]) + 2); - */ - - std::ostringstream before; - before << *stmt; - - // No change. - stmt = registerize(stmt); - - std::ostringstream after; - after << *stmt; - - ASSERT_EQ(before.str(), after.str()); -} - -// Conditional unhiding also works for IfThenElse exprs. -TEST(Registerizer, RegisterizerIfThenElseUnhidden) { - BufHandle a("A", {5}, kInt); - BufHandle b("B", {5}, kInt); - BufHandle c("C", {5}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - - StmtPtr stmt = Block::make({ - Store::make(a, {x}, 0), - Store::make( - b, - {y}, - IfThenElse::make( - CompareSelect::make(x, 5, CompareSelectOperation::kLT), - Add::make(Load::make(a, {x}), 1), - Add::make(Load::make(a, {x + 1}), 2))), - Store::make( - b, - {y + 1}, - IfThenElse::make( - CompareSelect::make(x, 5, CompareSelectOperation::kLT), - Add::make(Load::make(a, {x}), 1), - Add::make(Load::make(a, {x + 1}), 2))), - }); - - /* - * A[x] = 0; - * B[y] = IfThenElse(x<5 ? 1 : 0, (A[x]) + 1, (A[x + 1]) + 2); - * B[y + 1] = IfThenElse(x<5 ? 1 : 0, (A[x]) + 1, (A[x + 1]) + 2); - */ - - stmt = registerize(stmt); - - /* - * int A_1 = 0; - * B[y] = IfThenElse(x<5 ? 1 : 0, A_1 + 1, (A[x + 1]) + 2); - * B[y + 1] = IfThenElse(x<5 ? 1 : 0, A_1 + 1, (A[x + 1]) + 2); - * A[x] = A_1; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = 0; -# CHECK: B[y] = IfThenElse(x<5 ? 1 : 0, A_1 + 1, (A[x + 1]) + 2); -# CHECK: B[y + 1] = IfThenElse(x<5 ? 1 : 0, A_1 + 1, (A[x + 1]) + 2); -# CHECK: A[x] = A_1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Nested IfThenElse exprs can't promote to higher level scopes. -TEST(Registerizer, RegisterizerIfThenElseNested) { - BufHandle a("A", {5}, kInt); - BufHandle b("B", {5}, kInt); - BufHandle c("C", {5}, kInt); - BufHandle d("D", {5}, kInt); - VarHandle x("x", kInt); - - StmtPtr stmt = Block::make({Store::make( - a, - {x}, - IfThenElse::make( - CompareSelect::make(x, 3, CompareSelectOperation::kLT), - IfThenElse::make( - CompareSelect::make(x, 2, CompareSelectOperation::kEQ), - Load::make(d, {x}), - Load::make(b, {x})), - IfThenElse::make( - CompareSelect::make(x, 5, CompareSelectOperation::kEQ), - Load::make(c, {x}), - Load::make(d, {x}))))}); - - /* - * A[x] = IfThenElse(x<3 ? 1 : 0, - * IfThenElse(x==2 ? 
1 : 0, D[x], B[x]),
-   *            IfThenElse(x==5 ? 1 : 0, C[x], D[x]));
-   */
-
-  std::ostringstream before;
-  before << *stmt;
-
-  // No change.
-  stmt = registerize(stmt);
-
-  std::ostringstream after;
-  after << *stmt;
-
-  ASSERT_EQ(before.str(), after.str());
-}
-
-// Cannot registerize an access completely contained within an IfThenElse
-// branch, since it is not a Stmt and cannot hold variable definitions. We need
-// to check that we don't promote the initializer/finalizer to the enclosing
-// Block.
-TEST(Registerizer, RegisterizerIfThenElseInternal) {
-  // Making these floats so they don't get simplified to a single access.
-  BufHandle a("A", {5}, kFloat);
-  BufHandle b("B", {5}, kFloat);
-  VarHandle x("x", kInt);
-
-  StmtPtr stmt = Block::make({Store::make(
-      a,
-      {x},
-      IfThenElse::make(
-          CompareSelect::make(x, 3, CompareSelectOperation::kLT),
-          Add::make(Load::make(b, {x}), Load::make(b, {x})),
-          Load::make(b, {x})))});
-
-  /*
-   * A[x] = IfThenElse(x<3 ? 1 : 0, (B[x]) + (B[x]), B[x]);
-   */
-
-  std::ostringstream before;
-  before << *stmt;
-
-  // No change.
-  stmt = registerize(stmt);
-
-  std::ostringstream after;
-  after << *stmt;
-
-  ASSERT_EQ(before.str(), after.str());
-
-  // If this were a Cond instead of an IfThenElse then we could registerize the
-  // two accesses to B[x] in the True branch.
-
-  // Actually, let's verify that.
-
-  stmt = Block::make({Cond::make(
-      CompareSelect::make(x, 3, CompareSelectOperation::kLT),
-      Store::make(a, {x}, Add::make(Load::make(b, {x}), Load::make(b, {x}))),
-      Store::make(a, {x}, Load::make(b, {x})))});
-
-  /*
-   * if (x<3 ? 1 : 0) {
-   *   A[x] = (B[x]) + (B[x]);
-   * } else {
-   *   A[x] = B[x];
-   * }
-   */
-
-  stmt = registerize(stmt);
-
-  /*
-   * if (x<3 ? 1 : 0) {
-   *   float B_1 = B[x];
-   *   A[x] = B_1 + B_1;
-   * } else {
-   *   A[x] = B[x];
-   * }
-   */
-
-  std::ostringstream oss;
-  oss << *stmt;
-
-  const std::string& verification_pattern =
-      R"IR(
-# CHECK-NOT: int
-# CHECK-NOT: float
-# CHECK: if (x<3
-# CHECK: float B_1 =
-# CHECK: A[x] = B_1 + B_1
-# CHECK: } else {
-# CHECK: A[x] = B[x]
-# CHECK: }
-# CHECK-NOT: A[x]
-# CHECK-NOT: B[x])IR";
-
-  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
-}
-
-// Can registerize a load that occurs in the condition of an IfThenElse.
-TEST(Registerizer, RegisterizerIfThenElseCondition) {
-  BufHandle a("A", {5}, kInt);
-  BufHandle b("B", {5}, kInt);
-  BufHandle c("C", {5}, kInt);
-  VarHandle x("x", kInt);
-
-  StmtPtr stmt = Block::make(
-      {Store::make(a, {x}, Load::make(a, {x})),
-       Store::make(
-           a,
-           {x},
-           IfThenElse::make(
-               CompareSelect::make(
-                   Load::make(a, {x}), 5, CompareSelectOperation::kLT),
-               Load::make(b, {0}),
-               Load::make(c, {0})))});
-
-  /*
-   * A[x] = A[x]; <---- just here so there are enough accesses to combine.
-   * A[x] = IfThenElse((A[x])<5 ? 1 : 0, B[0], C[0]);
-   */
-
-  stmt = registerize(stmt);
-
-  /*
-   * int A_1 = A[x];
-   * A_1 = A_1;
-   * A_1 = IfThenElse(A_1<5 ? 1 : 0, B[0], C[0]);
-   * A[x] = A_1;
-   */
-
-  std::ostringstream oss;
-  oss << *stmt;
-
-  const std::string& verification_pattern =
-      R"IR(
-# CHECK: int A_1 = A[x];
-# CHECK: A_1 = IfThenElse(A_1<5 ? 1 : 0, B[0], C[0]);
-# CHECK: A[x] = A_1;)IR";
-
-  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
-}
-
-// Appearing in the condition of an IfThenElse likewise makes a load visible to
-// the enclosing scope, and so we can registerize internal usages.
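// [Editor's sketch; not part of the deleted file.] The condition operand of
// an IfThenElse is evaluated unconditionally, so a load appearing there
// proves the access is safe in the enclosing scope. In plain C++ the
// transform amounts to hoisting the common read:
static void sketchConditionUnhidden(int x) {
  int A[5] = {1, 2, 3, 4, 5};
  int B[5] = {0};
  // Was: B[x] = (A[x] < 5) ? A[x] + 1 : A[x] + 10;  (three reads of A[x])
  int A_1 = A[x]; // read once
  B[x] = (A_1 < 5) ? A_1 + 1 : A_1 + 10;
}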
-TEST(Registerizer, RegisterizerIfThenElseConditionUnhidden) {
-  BufHandle a("A", {5}, kInt);
-  BufHandle b("B", {5}, kInt);
-  BufHandle c("C", {5}, kInt);
-  VarHandle x("x", kInt);
-
-  StmtPtr stmt = Block::make({Store::make(
-      b,
-      {x},
-      IfThenElse::make(
-          CompareSelect::make(
-              Load::make(a, {x}), 5, CompareSelectOperation::kLT),
-          Add::make(Load::make(a, {x}), 1),
-          Add::make(Load::make(a, {x}), 10)))});
-
-  /*
-   * B[x] = IfThenElse((A[x])<5 ? 1 : 0, (A[x]) + 1, (A[x]) + 10);
-   */
-
-  stmt = registerize(stmt);
-
-  /*
-   * int A_1 = A[x];
-   * B[x] = IfThenElse(A_1<5 ? 1 : 0, A_1 + 1, A_1 + 10);
-   */
-
-  std::ostringstream oss;
-  oss << *stmt;
-
-  const std::string& verification_pattern =
-      R"IR(
-# CHECK: int A_1 = A[x];
-# CHECK: B[x] = IfThenElse(A_1<5 ? 1 : 0, A_1 + 1, A_1 + 10);)IR";
-
-  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
-}
-
-// Cannot promote accesses internal to IfThenElse branches even if the
-// enclosing scope is conditional.
-TEST(Registerizer, RegisterizerConditionBranchOnly) {
-  BufHandle a("A", {5}, kInt);
-  VarHandle x("x", kInt);
-  StmtPtr stmt = Block::make({For::make(
-      x,
-      0,
-      10,
-      Block::make({
-          Cond::make(
-              CompareSelect::make(x, 5, CompareSelectOperation::kLT),
-              Store::make(
-                  a,
-                  {x},
-                  IfThenElse::make(
-                      CompareSelect::make(x, 5, CompareSelectOperation::kLT),
-                      Add::make(Load::make(a, {x}), x),
-                      Add::make(Load::make(a, {x - 5}), x))),
-              Store::make(
-                  a,
-                  {x - 5},
-                  IfThenElse::make(
-                      CompareSelect::make(x, 5, CompareSelectOperation::kLT),
-                      Add::make(Load::make(a, {x}), x),
-                      Add::make(Load::make(a, {x - 5}), x)))),
-      }))});
-  stmt = IRSimplifier::simplify(stmt);
-
-  std::ostringstream before;
-  before << *stmt;
-
-  /* for (int x = 0; x < 10; x++) {
-   *   if (x<5 ? 1 : 0) {
-   *     A[x] = IfThenElse(x<5 ? 1 : 0, (A[x]) + x, (A[x - 5]) + x);
-   *   } else {
-   *     A[x - 5] = IfThenElse(x<5 ? 1 : 0, (A[x]) + x, (A[x - 5]) + x);
-   *   }
-   * }
-   */
-
-  // No change.
-  stmt = registerize(stmt);
-
-  std::ostringstream after;
-  after << *stmt;
-
-  ASSERT_EQ(before.str(), after.str());
-}
-
-// We can registerize an IfThenElse that appears in the condition branch of a
-// Cond. This is a weird but valid thing to do.
-TEST(Registerizer, RegisterizerCondIfThenElse) {
-  BufHandle a("A", {5}, kInt);
-  BufHandle b("B", {5}, kInt);
-  BufHandle c("C", {5}, kInt);
-  VarHandle x("x", kInt);
-
-  StmtPtr stmt = Block::make({Cond::make(
-      CompareSelect::make(
-          IfThenElse::make(
-              CompareSelect::make(
-                  Load::make(a, {x}), 5, CompareSelectOperation::kLT),
-              Load::make(a, {x}),
-              Load::make(b, {x})),
-          x,
-          CompareSelectOperation::kEQ),
-      Store::make(c, {x}, Add::make(Load::make(c, {x}), 1)),
-      nullptr)});
-
-  /*
-   * if ((IfThenElse((A[x])<5 ? 1 : 0, A[x], B[x]))==x ? 1 : 0) {
-   *   C[x] = (C[x]) + 1;
-   * }
-   */
-
-  stmt = registerize(stmt);
-
-  // The access to A can be registerized, but not B or C.
-
-  /*
-   * int A_1 = A[x];
-   * if ((IfThenElse(A_1<5 ? 1 : 0, A_1, B[x]))==x ? 1 : 0) {
-   *   C[x] = (C[x]) + 1;
-   * }
-   */
-
-  std::ostringstream oss;
-  oss << *stmt;
-
-  const std::string& verification_pattern =
-      R"IR(
-# CHECK: int A_1 = A[x];
-# CHECK: if ((IfThenElse(A_1<5 ? 1 : 0, A_1, B[x]
-# CHECK: C[x] = (C[x]) + 1;)IR";
-
-  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
-}
-
-// Can registerize a conditional access in the RHS of a store unhidden by its
-// LHS, and hoist it out of a loop.
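// [Editor's sketch; not part of the deleted file.] Because the LHS store to
// A[x] executes on every iteration, the conditional read of A[x] on the RHS
// is unhidden, and the whole access can live in a scalar across the loop:
static void sketchHoistedConditionalAccess(int x) {
  int A[5] = {0};
  int B[10] = {0};
  int A_1 = A[x]; // initializer hoisted above the loop
  for (int y = 0; y < 10; y++) {
    A_1 = (x < 3) ? A_1 : B[y];
  }
  A[x] = A_1; // finalizer emitted once, after the loop
}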
-TEST(Registerizer, RegisterizerIfThenElseLoop) { - BufHandle a("A", {5}, kInt); - BufHandle b("B", {5}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - - StmtPtr stmt = For::make( - y, - 0, - 10, - Store::make( - a, - {x}, - IfThenElse::make( - CompareSelect::make(x, 3, CompareSelectOperation::kLT), - Load::make(a, {x}), - Load::make(b, {y})))); - - /* - * for (int y = 0; y < 10; y++) { - * A[x] = IfThenElse(x<3 ? 1 : 0, A[x], B[y]); - * } - */ - - stmt = registerize(stmt); - - /* - * int A_1 = A[x]; - * for (int y = 0; y < 10; y++) { - * A_1 = IfThenElse(x<3 ? 1 : 0, A_1, B[y]); - * } - * A[x] = A_1; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = A[x]; -# CHECK: for ( -# CHECK: A_1 = IfThenElse(x<3 ? 1 : 0, A_1, B[y]); -# CHECK: } -# CHECK: A[x] = A_1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Cannot registerize if the RHS overlaps the access creating visibility. -TEST(Registerizer, RegisterizerIfThenElseLoopCut) { - BufHandle a("A", {5}, kInt); - BufHandle b("B", {5}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - - StmtPtr stmt = Block::make({For::make( - y, - 0, - 10, - Store::make( - a, - {x}, - IfThenElse::make( - CompareSelect::make(x, 3, CompareSelectOperation::kLT), - Load::make(a, {x}), - Load::make(a, {y}))))}); - - /* - * for (int y = 0; y < 10; y++) { - * A[x] = IfThenElse(x<3 ? 1 : 0, A[x], A[y]); - * } - */ - - std::ostringstream before; - before << *stmt; - - // No change. - stmt = registerize(stmt); - - std::ostringstream after; - after << *stmt; - - ASSERT_EQ(before.str(), after.str()); -} - -// Simple case where an access is cut by an overlapping access later in the -// program, we can registerize up until the overlap. -TEST(Registerizer, RegisterizerPartialAfter) { - BufHandle a("A", {1}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make( - {Store::make(a, {0}, 0), - For::make( - x, - 0, - 10, - Block::make( - {Store::make(a, {0}, Add::make(Load::make(a, {0}), x))})), - For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1})))}); - - /* - * A[0] = 0; - * for (int x = 0; x < 10; x++) { - * A[0] = (A[0]) + x; - * } - * for (int x = 1; x < 10; x++) { - * A[x] = A[x - 1]; - * } - */ - - stmt = registerize(stmt); - - /* - * int A_1 = 0; - * for (int x = 0; x < 10; x++) { - * A_1 = A_1 + x; - * } - * A[0] = A_1; - * for (int x = 1; x < 10; x++) { - * A[x] = A[x - 1]; - * } - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = 0; -# CHECK: for ( -# CHECK: A_1 = A_1 + x; -# CHECK: } -# CHECK: A[0] = A_1; -# CHECK: for ( -# CHECK: A[x] = A[x - 1]; -# CHECK: } -# CHECK-NOT: A)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// We can registerize an access which overlaps a previous access, the -// initializer must be inserted after the previous access. 
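// [Editor's sketch; not part of the deleted file.] When an overlapping
// access precedes the registerizable region, the scalar's initializer is
// simply placed after that access and the region is registerized from there
// on. Plain C++ analogue:
static void sketchPartialBefore() {
  int A[10] = {0};
  for (int x = 1; x < 10; x++) {
    A[x] = A[x - 1]; // overlapping prefix stays untouched
  }
  int A_1 = 0; // initializer inserted after the overlap
  for (int x = 0; x < 10; x++) {
    A_1 = A_1 + x;
  }
  A[0] = A_1;
}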
-TEST(Registerizer, RegisterizerPartialBefore) { - BufHandle a("A", {1}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make( - {For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1}))), - Store::make(a, {0}, 0), - For::make( - x, - 0, - 10, - Block::make( - {Store::make(a, {0}, Add::make(Load::make(a, {0}), x))}))}); - - /* - * for (int x = 1; x < 10; x++) { - * A[x] = A[x - 1]; - * } - * A[0] = 0; - * for (int x = 0; x < 10; x++) { - * A[0] = (A[0]) + x; - * } - */ - - stmt = registerize(stmt); - - /* - * for (int x = 1; x < 10; x++) { - * A[x] = A[x - 1]; - * } - * int A_1 = 0; - * for (int x = 0; x < 10; x++) { - * A_1 = A_1 + x; - * } - * A[0] = A_1; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK-NOT: int -# CHECK: for ( -# CHECK: A[x] = A[x - 1]; -# CHECK: } -# CHECK: int A_1 = 0; -# CHECK: for ( -# CHECK: A_1 = A_1 + x; -# CHECK: } -# CHECK: A[0] = A_1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// The combination of the previous two tests, an access is cut by an overlapping -// access in both directions. -TEST(Registerizer, RegisterizerPartialInside) { - BufHandle a("A", {1}, kInt); - VarHandle x1("x1", kInt); - VarHandle x2("x2", kInt); - VarHandle x3("x3", kInt); - StmtPtr stmt = Block::make( - {Store::make(a, {0}, 2), - For::make( - x1, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), x1))), - For::make(x2, 1, 10, Store::make(a, {x2}, Load::make(a, {x2 - 1}))), - For::make( - x3, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), x3)))}); - - /* - * A[0] = 2; - * for (int x1 = 0; x1 < 10; x1++) { - * A[0] = (A[0]) + x1; - * } - * for (int x2 = 1; x2 < 10; x2++) { - * A[x2] = A[x2 - 1]; - * } - * for (int x3 = 0; x3 < 10; x3++) { - * A[0] = (A[0]) + x3; - * } - */ - - stmt = registerize(stmt); - - /* - * int A_1 = 2; - * for (int x1 = 0; x1 < 10; x1++) { - * A_1 = A_1 + x1; - * } - * A[0] = A_1; - * for (int x2 = 1; x2 < 10; x2++) { - * A[x2] = A[x2 - 1]; - * } - * int A_2 = A[0]; - * for (int x3 = 0; x3 < 10; x3++) { - * A_2 = A_2 + x3; - * } - * A[0] = A_2; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = 2; -# CHECK: for ( -# CHECK: A_1 = A_1 + x1; -# CHECK: } -# CHECK: A[0] = A_1; -# CHECK: for ( -# CHECK: A[x2] = -# CHECK: } -# CHECK: int A_2 = A[0]; -# CHECK: for ( -# CHECK: A_2 = A_2 + x3; -# CHECK: } -# CHECK: A[0] = A_2;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// An element could be registerized program wide but is cut by a conditional -// access, we should break this into two scalars and write back to the buffer -// before the condition. -TEST(Registerizer, RegisterizerPartialCondition) { - BufHandle a("A", {1}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make( - {Store::make(a, {0}, 2), - For::make( - x, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), x))), - Cond::make( - CompareSelect::make(x, 5, CompareSelectOperation::kLT), - Store::make(a, {x}, Load::make(a, {x - 1})), - nullptr), - For::make( - x, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), x)))}); - - /* - * A[0] = 2; - * for (int x = 0; x < 10; x++) { - * A[0] = (A[0]) + x; - * } - * if (x<5 ? 
1 : 0) { - * A[x] = A[x - 1]; - * } - * for (int x = 0; x < 10; x++) { - * A[0] = (A[0]) + x; - * } - */ - - stmt = registerize(stmt); - - /* - * int A_1 = 2; - * for (int x = 0; x < 10; x++) { - * A_1 = A_1 + x; - * } - * A[0] = A_1; - * if (x<5 ? 1 : 0) { - * A[x] = A[x - 1]; - * } - * int A_2 = A[0]; - * for (int x = 0; x < 10; x++) { - * A_2 = A_2 + x; - * } - * A[0] = A_2; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = 2; -# CHECK: for ( -# CHECK: A_1 = A_1 + x; -# CHECK: } -# CHECK: A[0] = A_1; -# CHECK: if ( -# CHECK: A[x] = -# CHECK: } -# CHECK: int A_2 = A[0]; -# CHECK: for ( -# CHECK: A_2 = A_2 + x; -# CHECK: } -# CHECK: A[0] = A_2;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Tests case where an access is cut by an internal conditional access which -// itself is registerized. -TEST(Registerizer, RegisterizerPartialConditionInternalCut) { - BufHandle a("A", {1}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make( - {Store::make(a, {0}, 1), - Store::make(a, {0}, 3), - Cond::make( - CompareSelect::make(x, 5, CompareSelectOperation::kLT), - Block::make({Store::make(a, {x}, 1), Store::make(a, {x}, 3)}), - nullptr), - Store::make(a, {0}, 4), - Store::make(a, {0}, 6)}); - - /* - * A[0] = 1; - * A[0] = 3; - * if (x<5 ? 1 : 0) { - * A[x] = 1; - * A[x] = 3; - * } - * A[0] = 4; - * A[0] = 6; - */ - - stmt = registerize(stmt); - - /* - * int A_1 = 1; - * A_1 = 3; - * A[0] = A_1; - * if (x<5 ? 1 : 0) { - * int A_2 = 1; - * A_2 = 3; - * A[x] = A_2; - * } - * int A_3 = 4; - * A_3 = 6; - * A[0] = A_3; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = 1; -# CHECK: A_1 = 3 -# CHECK: A[0] = A_1; -# CHECK: if ( -# CHECK: int A_2 = 1; -# CHECK: A_2 = 3; -# CHECK: A[x] = A_2; -# CHECK: } -# CHECK: int A_3 = 4; -# CHECK: A_3 = 6; -# CHECK: A[0] = A_3;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// First statement in condition closes outer access, but can be registerized -// with later statements. -TEST(Registerizer, RegisterizerPartialConditionInternalStart) { - BufHandle a("A", {1}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make( - {Store::make(a, {0}, 1), - Store::make(a, {0}, 3), - Cond::make( - CompareSelect::make(x, 5, CompareSelectOperation::kLT), - Block::make({Store::make(a, {x}, 1), Store::make(a, {x}, 3)}), - nullptr), - Store::make(a, {x}, 4), - Store::make(a, {x}, 6)}); - - /* - * A[0] = 1; - * A[0] = 3; - * if (x<5 ? 1 : 0) { - * A[x] = 1; - * A[x] = 3; - * } - * A[x] = 4; - * A[x] = 6; - */ - - stmt = registerize(stmt); - - /* - * int A_1 = 1; - * A_1 = 3; - * A[0] = A_1; - * int A_2 = A[x]; <--- must read from the input here. - * if (x<5 ? 1 : 0) { - * A_2 = 1; - * A_2 = 3; - * } - * A_2 = 4; - * A_2 = 6; - * A[x] = A_2; - */ - - // TODO: I suppose we could refactor with a conditional initializer? - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = 1; -# CHECK: A_1 = 3 -# CHECK: A[0] = A_1; -# CHECK: int A_2 = A[x]; -# CHECK: if ( -# CHECK: A_2 = 1; -# CHECK: A_2 = 3; -# CHECK: } -# CHECK: A_2 = 4; -# CHECK: A_2 = 6; -# CHECK: A[x] = A_2;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// An access cuts two open overlaps and creates four scalar variables. 
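The "cut" behaviour exercised by the partial tests is easiest to see outside the IR. The following is a hand-written C++ rendering of the RegisterizerPartialAfter case above, for illustration only and not output of the pass: the scalar is live only up to the first statement that may alias A[0], and is written back at that point.

// Before registerization: A[0] is accessed repeatedly, then an overlapping
// access pattern begins.
void before(int* A) {
  A[0] = 0;
  for (int x = 0; x < 10; x++) {
    A[0] = A[0] + x; // A[0] has a fixed address, so it can live in a scalar
  }
  for (int x = 1; x < 10; x++) {
    A[x] = A[x - 1]; // accesses that may alias A[0] end ("cut") the scalar's range
  }
}

// After registerization: the scalar covers only the region before the cut.
void after(int* A) {
  int A_1 = 0; // scalar replaces A[0] from its initializer...
  for (int x = 0; x < 10; x++) {
    A_1 = A_1 + x;
  }
  A[0] = A_1; // ...and is flushed back just before the overlapping access
  for (int x = 1; x < 10; x++) {
    A[x] = A[x - 1]; // left untouched; registerization stopped at the cut
  }
}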
-TEST(Registerizer, RegisterizerPartialOverlapsTwo) {
-  BufHandle a("A", {1}, kInt);
-  VarHandle x("x", kInt);
-  StmtPtr stmt = Block::make(
-      {Store::make(a, {1}, Load::make(a, {0})),
-       Store::make(a, {0}, Load::make(a, {1})),
-       Store::make(a, {0}, Load::make(a, {1})),
-       For::make(x, 1, 10, Store::make(a, {x}, x)),
-       Store::make(a, {1}, Load::make(a, {0})),
-       Store::make(a, {0}, Load::make(a, {1})),
-       Store::make(a, {0}, Load::make(a, {1}))});
-
-  /*
-   * A[1] = A[0];
-   * A[0] = A[1];
-   * A[0] = A[1];
-   * for (int x = 1; x < 10; x++) {
-   *   A[x] = x;
-   * }
-   * A[1] = A[0];
-   * A[0] = A[1];
-   * A[0] = A[1];
-   */
-
-  stmt = registerize(stmt);
-
-  /*
-   * int A_1 = A[0];
-   * int A_2 = A_1;
-   * A_1 = A_2;
-   * A_1 = A_2;
-   * A[1] = A_2;
-   * A[0] = A_1;
-   * for (int x = 1; x < 10; x++) {
-   *   A[x] = x;
-   * }
-   * int A_3 = A[0];
-   * int A_4 = A_3;
-   * A_3 = A_4;
-   * A_3 = A_4;
-   * A[1] = A_4;
-   * A[0] = A_3;
-   */
-
-  std::ostringstream oss;
-  oss << *stmt;
-
-  const std::string& verification_pattern =
-      R"IR(
-# CHECK: int A_1 = A[0];
-# CHECK: int A_2 = A_1;
-# CHECK: A_1 = A_2;
-# CHECK: A_1 = A_2;
-# CHECK: A[1] = A_2;
-# CHECK: A[0] = A_1;
-# CHECK: for (
-# CHECK: A[x] = x;
-# CHECK: }
-# CHECK: int A_3 = A[0];
-# CHECK: int A_4 = A_3;
-# CHECK: A_3 = A_4;
-# CHECK: A_3 = A_4;
-# CHECK: A[1] = A_4;
-# CHECK: A[0] = A_3;)IR";
-
-  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
-}
-
-// Nested blocks will automatically be flattened and do not prevent
-// registerization of enclosed accesses.
-TEST(Registerizer, RegisterizerNestedBlocks) {
-  BufHandle a("A", {1}, kInt);
-  VarHandle x("x", kInt);
-  StmtPtr stmt = Block::make(
-      // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
-      {Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
-       Block::make({Store::make(a, {0}, Add::make(Load::make(a, {0}), 2))}),
-       Block::make(
-           {Store::make(a, {0}, Add::make(Load::make(a, {0}), 3)),
-            Block::make(
-                {Store::make(a, {0}, Add::make(Load::make(a, {0}), 4))})})});
-
-  /*
-   * A[0] = (A[0]) + 1;
-   * {
-   *   A[0] = (A[0]) + 2;
-   * }
-   * {
-   *   A[0] = (A[0]) + 3;
-   *   {
-   *     A[0] = (A[0]) + 4;
-   *   }
-   * }
-   */
-
-  stmt = registerize(stmt);
-
-  /*
-   * int A_1 = A[0];
-   * A_1 = A_1 + 1;
-   * A_1 = A_1 + 2;
-   * A_1 = A_1 + 3;
-   * A_1 = A_1 + 4;
-   * A[0] = A_1;
-   */
-
-  std::ostringstream oss;
-  oss << *stmt;
-
-  const std::string& verification_pattern =
-      R"IR(
-# CHECK: int A_1 = A[0];
-# CHECK: A_1 = A_1 + 1;
-# CHECK: A_1 = A_1 + 2;
-# CHECK: A_1 = A_1 + 3;
-# CHECK: A_1 = A_1 + 4;
-# CHECK: A[0] = A_1;)IR";
-
-  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
-}
-
-// The access can be registerized internally to a condition, but must ensure
-// that both initializer and finalizer are within the same condition.
-TEST(Registerizer, RegisterizerNestedConditions) {
-  BufHandle a("A", {1}, kInt);
-  VarHandle x("x", kInt);
-  StmtPtr stmt = Block::make({Cond::make(
-      CompareSelect::make(x, 5, CompareSelectOperation::kLT),
-      Block::make(
-          {Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
-           Cond::make(
-               CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
-               Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
-               nullptr)}),
-      nullptr)});
-
-  /*
-   * if (x<5 ? 1 : 0) {
-   *   A[0] = (A[0]) + 1;
-   *   if (x==2 ? 1 : 0) {
-   *
-   *     A[0] = (A[0]) + 1;
-   *   }
-   * }
-   */
-
-  stmt = registerize(stmt);
-
-  /*
-   * if (x<5 ? 1 : 0) {
-   *   int A_1 = A[0];
-   *   A_1 = A_1 + 1;
-   *   if (x==2 ?
1 : 0) { - * A_1 = A_1 + 1; - * } - * A[0] = A_1; - * } - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: if (x<5 -# CHECK: int A_1 = A[0]; -# CHECK: A_1 = A_1 + 1; -# CHECK: if (x==2 -# CHECK: A_1 = A_1 + 1; -# CHECK: } -# CHECK: A[0] = A_1; -# CHECK: })IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// If an access exists outside the scope of the condition then we can lift -// nested conditional usages into the same scalar. -TEST(Registerizer, RegisterizerNestedConditionsUnhidden) { - BufHandle a("A", {1}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make( - {Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), - Cond::make( - CompareSelect::make(x, 5, CompareSelectOperation::kLT), - Block::make( - {Store::make(a, {1}, 1), - Cond::make( - CompareSelect::make(x, 2, CompareSelectOperation::kEQ), - Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), - nullptr)}), - nullptr)}); - - /* - * A[0] = (A[0]) + 1; - * if (x<5 ? 1 : 0) { - * A[1] = 1; - * if (x==2 ? 1 : 0) { - * A[0] = (A[0]) + 1; - * } - * } - */ - - stmt = registerize(stmt); - - /* - * int A_1 = A[0]; - * A_1 = A_1 + 1; - * if (x<5 ? 1 : 0) { - * A[1] = 1; - * if (x==2 ? 1 : 0) { - * A_1 = A_1 + 1; - * } - * } - * A[0] = A_1; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = A[0]; -# CHECK: A_1 = A_1 + 1; -# CHECK: if (x<5 -# CHECK: A[1] = 1; -# CHECK: if (x==2 -# CHECK: A_1 = A_1 + 1; -# CHECK: A[0] = A_1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(Registerizer, RegisterizerNestedConditionsHiddenFirst) { - BufHandle a("A", {1}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make( - {Cond::make( - CompareSelect::make(x, 2, CompareSelectOperation::kEQ), - Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), - nullptr), - Cond::make( - CompareSelect::make(x, 5, CompareSelectOperation::kLT), - Block::make({Cond::make( - CompareSelect::make(x, 2, CompareSelectOperation::kEQ), - Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), - nullptr)}), - nullptr)}); - - /* - * if (x==2 ? 1 : 0) { - * A[0] = (A[0]) + 1; - * } - * if (x<5 ? 1 : 0) { - * if (x==2 ? 1 : 0) { - * A[0] = (A[0]) + 1; - * } - * } - */ - - std::ostringstream before; - before << *stmt; - - // No change. - stmt = registerize(stmt); - - std::ostringstream after; - after << *stmt; - - ASSERT_EQ(before.str(), after.str()); - - // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) - stmt = registerize(stmt); -} - -TEST(Registerizer, RegisterizerNestedConditionsHiddenSecond) { - BufHandle a("A", {1}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make( - {Cond::make( - CompareSelect::make(x, 5, CompareSelectOperation::kLT), - Block::make({Cond::make( - CompareSelect::make(x, 2, CompareSelectOperation::kEQ), - Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), - nullptr)}), - nullptr), - Cond::make( - CompareSelect::make(x, 2, CompareSelectOperation::kEQ), - Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), - nullptr)}); - - /* - * if (x<5 ? 1 : 0) { - * if (x==2 ? 1 : 0) { - * A[0] = (A[0]) + 1; - * } - * } - * if (x==2 ? 1 : 0) { - * A[0] = (A[0]) + 1; - * } - */ - - std::ostringstream before; - before << *stmt; - - // No change. 
- stmt = registerize(stmt); - - std::ostringstream after; - after << *stmt; - - ASSERT_EQ(before.str(), after.str()); - - // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) - stmt = registerize(stmt); -} - -// If an access is cut by another access internal to a condition block, it still -// cuts the access. -TEST(Registerizer, RegisterizerNestedConditionsCut) { - BufHandle a("A", {1}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make( - {Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), - Cond::make( - CompareSelect::make(x, 5, CompareSelectOperation::kLT), - Block::make( - {Store::make(a, {x}, 1), - Cond::make( - CompareSelect::make(x, 2, CompareSelectOperation::kEQ), - Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), - nullptr)}), - nullptr)}); - - /* - * A[0] = (A[0]) + 1; - * if (x<5 ? 1 : 0) { - * A[x] = 1; - * if (x==2 ? 1 : 0) { - * - * A[0] = (A[0]) + 1; - * } - * } - */ - - std::ostringstream before; - before << *stmt; - - // No change. - stmt = registerize(stmt); - - std::ostringstream after; - after << *stmt; - - ASSERT_EQ(before.str(), after.str()); -} - -TEST(Registerizer, RegisterizerNestedConditionLoopHidden) { - BufHandle a("A", {10}, kInt); - BufHandle b("B", {10}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make( - {Cond::make( - CompareSelect::make(x, 2, CompareSelectOperation::kEQ), - Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), - nullptr), - For::make( - x, - 0, - 10, - Block::make( - {Store::make(b, {x}, 0), - Cond::make( - CompareSelect::make(x, 2, CompareSelectOperation::kEQ), - Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), - nullptr)}))}); - - /* - * if (x==2 ? 1 : 0) { - * A[0] = (A[0]) + 1; - * } - * for (int x = 0; x < 10; x++) { - * B[x] = 0; <-- this is only here to prevent Loop/Cond reordering. - * if (x==2 ? 1 : 0) { - * A[0] = (A[0]) + 1; - * } - * } - */ - - std::ostringstream before; - before << *stmt; - - // No change. - stmt = registerize(stmt); - - std::ostringstream after; - after << *stmt; - - ASSERT_EQ(before.str(), after.str()); -} - -// Three loops and four element regions, three of which should be registerized -// at different levels of the IR. -TEST(Registerizer, RegisterizerNestedConditionThreeDeep) { - BufHandle a("A", {10}, kInt); - BufHandle b("B", {10}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make( - {Store::make(a, {4}, 0), - Cond::make( - CompareSelect::make(x, 2, CompareSelectOperation::kGT), - Cond::make( - CompareSelect::make(x, 3, CompareSelectOperation::kGT), - Block::make({ - Cond::make( - CompareSelect::make(x, 4, CompareSelectOperation::kGT), - Block::make({ - Store::make( - a, {1}, Add::make(Load::make(a, {1}), 1)), - Store::make( - a, {2}, Add::make(Load::make(a, {2}), 1)), - Store::make( - a, {3}, Add::make(Load::make(a, {3}), 1)), - Store::make( - a, {4}, Add::make(Load::make(a, {4}), 1)), - Store::make( - a, {1}, Add::make(Load::make(a, {1}), 1)), - }), - nullptr), - Store::make(a, {2}, Add::make(Load::make(a, {2}), 1)), - }), - nullptr), - nullptr)}); - - /* - * A[4] = 0; - * if (x>2 ? 1 : 0) { - * if (x>3 ? 1 : 0) { - * if (x>4 ? 1 : 0) { - * A[1] = (A[1]) + 1; - * A[2] = (A[2]) + 1; - * A[3] = (A[3]) + 1; - * A[4] = (A[4]) + 1; - * A[1] = (A[1]) + 1; - * } - * A[2] = (A[2]) + 1; - * } - * } - */ - - stmt = registerize(stmt); - - /* - * int A_1 = 0; - * if (x>2 ? 1 : 0) { - * if (x>3 ? 1 : 0) { - * int A_3 = A[2]; - * if (x>4 ? 
1 : 0) { - * int A_2 = A[1]; - * A_2 = A_2 + 1; - * A_3 = A_3 + 1; - * A[3] = (A[3]) + 1; - * A_1 = A_1 + 1; - * A_2 = A_2 + 1; - * A[1] = A_2; - * } - * A_3 = A_3 + 1; - * A[2] = A_3; - * } - * } - * A[4] = A_1; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = 0; -# CHECK: if (x>2 ? 1 : 0) { -# CHECK: if (x>3 ? 1 : 0) { -# CHECK: int A_3 = A[2]; -# CHECK: if (x>4 ? 1 : 0) { -# CHECK: int A_2 = A[1]; -# CHECK: A_2 = A_2 + 1; -# CHECK: A_3 = A_3 + 1; -# CHECK: A[3] = (A[3]) + 1; -# CHECK: A_1 = A_1 + 1; -# CHECK: A_2 = A_2 + 1; -# CHECK: A[1] = A_2; -# CHECK: } -# CHECK: A_3 = A_3 + 1; -# CHECK: A[2] = A_3; -# CHECK: } -# CHECK: } -# CHECK: A[4] = A_1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Can replace a simple scalar access with a local variable even when that -// variable is an outer loop var. -TEST(Registerizer, RegisterizerNestedLoopSimple) { - BufHandle a("A", {1}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - StmtPtr stmt = Block::make({For::make( - y, - 0, - 10, - For::make( - x, - 0, - 10, - Block::make( - {Store::make(a, {y}, Add::make(Load::make(a, {y}), x))})))}); - - /* - * for (int y = 0; y < 10; y++) { - * for (int x = 0; x < 10; x++) { - * A[y] = (A[y]) + x; - * } - * } - */ - - stmt = registerize(stmt); - - /* - * for (int y = 0; y < 10; y++) { - * int A_1 = A[y]; - * for (int x = 0; x < 10; x++) { - * A_1 = A_1 + x; - * } - * A[y] = A_1; - * } - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: for (int y -# CHECK: int A_1 = A[y]; -# CHECK: for (int x -# CHECK: A_1 = A_1 + x; -# CHECK: } -# CHECK: A[y] = A_1; -# CHECK: })IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Test the positive case of the hiddenAccess split, where an internal -// conditional access can be hoisted up through a loop to match an existing -// access in a higher scope and the two can be registerized. -TEST(Registerizer, RegisterizerHiddenAccessYes) { - BufHandle a("A", {10}, kInt); - BufHandle b("B", {10}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - StmtPtr stmt = Block::make({Cond::make( - CompareSelect::make(x, 2, CompareSelectOperation::kEQ), - Block::make( - {Store::make(a, {0}, 0), - For::make( - x, - 0, - 10, - Block::make( - {Store::make(b, {x}, 0), - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - Cond::make( - CompareSelect::make(x, 3, CompareSelectOperation::kEQ), - For::make( - y, - 0, - 10, - Store::make( - a, {0}, Add::make(Load::make(a, {0}), 1))), - nullptr)}))}), - nullptr)}); - - /* - * if (x==2 ? 1 : 0) { - * A[0] = 0; - * for (int x = 0; x < 10; x++) { - * B[x] = 0; - * if (x==3 ? 1 : 0) { - * for (int y = 0; y < 10; y++) { - * A[0] = (A[0]) + 1; - * } - * } - * } - * } - */ - - stmt = registerize(stmt); - - /* - * if (x==2 ? 1 : 0) { - * int A_1 = 0; - * for (int x = 0; x < 10; x++) { - * B[x] = 0; - * if (x==3 ? 
1 : 0) { - * for (int y = 0; y < 10; y++) { - * A_1 = A_1 + 1; - * } - * } - * } - * A[0] = A_1; - * } - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: if (x==2 -# CHECK: int A_1 = 0; -# CHECK: for (int x -# CHECK: B[x] = 0; -# CHECK: if (x==3 -# CHECK: for (int y -# CHECK: A_1 = A_1 + 1; -# CHECK: } -# CHECK: } -# CHECK: } -# CHECK: A[0] = A_1; -# CHECK: })IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Test the negative case of the hiddenAccess split, where the hoisted access is -// never unhidden at a higher scope and registerization occurs at the lower -// scope. -TEST(Registerizer, RegisterizerHiddenAccessNo) { - BufHandle a("A", {10}, kInt); - BufHandle b("B", {10}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - StmtPtr stmt = Block::make({Cond::make( - CompareSelect::make(x, 2, CompareSelectOperation::kEQ), - Block::make({For::make( - x, - 0, - 10, - Block::make( - {Store::make(b, {x}, 0), - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - Cond::make( - CompareSelect::make(x, 3, CompareSelectOperation::kEQ), - For::make( - y, - 0, - 10, - Store::make(a, {0}, Add::make(Load::make(a, {0}), 1))), - nullptr)}))}), - nullptr)}); - - /* - * if (x==2 ? 1 : 0) { - * A[0] = 0; - * for (int x = 0; x < 10; x++) { - * B[x] = 0; - * if (x==3 ? 1 : 0) { - * for (int y = 0; y < 10; y++) { - * A[0] = (A[0]) + 1; - * } - * } - * } - * } - */ - - stmt = registerize(stmt); - - /* - * if (x==2 ? 1 : 0) { - * for (int x = 0; x < 10; x++) { - * B[x] = 0; - * if (x==3 ? 1 : 0) { - * int A_1 = A[0]; - * for (int y = 0; y < 10; y++) { - * A_1 = A_1 + 1; - * } - * A[0] = A_1; - * } - * } - * } - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: if (x==2 -# CHECK: for (int x -# CHECK: B[x] = 0; -# CHECK: if (x==3 -# CHECK: int A_1 = A[0]; -# CHECK: for (int y -# CHECK: A_1 = A_1 + 1; -# CHECK: } -# CHECK: A[0] = A_1; -# CHECK: } -# CHECK: } -# CHECK: })IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// In this case the conditional access must be hoisted by two loops, there are -// two accesses here one is unhidden and the other isn't. A[0] can be -// registerized but B[0] cannot. -TEST(Registerizer, RegisterizerHiddenAccessMultiLoop) { - BufHandle a("A", {10}, kInt); - BufHandle b("B", {10}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - StmtPtr stmt = Block::make({Cond::make( - CompareSelect::make(x, 2, CompareSelectOperation::kEQ), - Block::make( - {Store::make(a, {0}, 0), - For::make( - x, - 0, - 10, - For::make( - y, - 0, - 10, - Block::make({Cond::make( - CompareSelect::make(y, 3, CompareSelectOperation::kEQ), - Block::make( - {Store::make( - a, {0}, Add::make(Load::make(a, {0}), 1)), - Store::make( - b, {0}, Add::make(Load::make(b, {0}), 1))}), - nullptr)})))}), - nullptr)}); - - /* - * if (x==2 ? 1 : 0) { - * A[0] = 0; - * for (int x = 0; x < 10; x++) { - * for (int y = 0; y < 10; y++) { - * if (y==3 ? 1 : 0) { - * A[0] = (A[0]) + 1; - * B[0] = (B[0]) + 1; - * } - * } - * } - * } - */ - - stmt = registerize(stmt); - - /* - * if (x==2 ? 1 : 0) { - * int A_1 = 0; - * for (int x = 0; x < 10; x++) { - * for (int y = 0; y < 10; y++) { - * if (y==3 ? 
1 : 0) { - * A_1 = A_1 + 1; - * B[0] = (B[0]) + 1; - * } - * } - * } - * A[0] = A_1; - * } - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: if (x==2 -# CHECK: int A_1 = 0; -# CHECK: for (int x -# CHECK: for (int y -# CHECK: if (y==3 -# CHECK: A_1 = A_1 + 1; -# CHECK: B[0] = (B[0]) + 1; -# CHECK: } -# CHECK: } -# CHECK: } -# CHECK: A[0] = A_1; -# CHECK: })IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Accesses are registerized inside two conditions, but the immediate parent is -// not a condition. -TEST(Registerizer, RegisterizerTwoConditionalLoops) { - BufHandle a("A", {1}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make( - {Cond::make( - CompareSelect::make(x, 5, CompareSelectOperation::kLT), - For::make( - x, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), 1))), - nullptr), - Cond::make( - CompareSelect::make(x, 5, CompareSelectOperation::kGT), - For::make( - x, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), 1))), - nullptr)}); - - /* - * if (x<5 ? 1 : 0) { - * for (int x = 0; x < 10; x++) { - * A[0] = (A[0]) + 1; - * } - * } - * if (x>5 ? 1 : 0) { - * for (int x = 0; x < 10; x++) { - * A[0] = (A[0]) + 1; - * } - * } - */ - - stmt = registerize(stmt); - - /* - * if (x<5 ? 1 : 0) { - * int A_1 = A[0]; - * for (int x = 0; x < 10; x++) { - * A_1 = A_1 + 1; - * } - * A[0] = A_1; - * } - * if (x>5 ? 1 : 0) { - * int A_2 = A[0]; - * for (int x = 0; x < 10; x++) { - * A_2 = A_2 + 1; - * } - * A[0] = A_2; - * } - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: if (x<5 -# CHECK: int A_1 = A[0]; -# CHECK: for (int x -# CHECK: A_1 = A_1 + 1; -# CHECK: } -# CHECK: A[0] = A_1; -# CHECK: } -# CHECK: if (x>5 -# CHECK: int A_2 = A[0]; -# CHECK: for (int x -# CHECK: A_2 = A_2 + 1; -# CHECK: } -# CHECK: A[0] = A_2; -# CHECK: })IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Accesses are registerized inside two conditions, cut in the middle. -TEST(Registerizer, RegisterizerTwoConditionalLoopsCut) { - BufHandle a("A", {1}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make( - {Cond::make( - CompareSelect::make(x, 5, CompareSelectOperation::kLT), - For::make( - x, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), 1))), - nullptr), - For::make(x, 0, 10, Store::make(a, {x}, 1)), - Cond::make( - CompareSelect::make(x, 5, CompareSelectOperation::kGT), - For::make( - x, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), 1))), - nullptr)}); - - /* - * if (x<5 ? 1 : 0) { - * for (int x = 0; x < 10; x++) { - * A[0] = (A[0]) + 1; - * } - * } - * for (int x = 0; x < 10; x++) { - * A[x] = 1; - * } - * if (x>5 ? 1 : 0) { - * for (int x = 0; x < 10; x++) { - * A[0] = (A[0]) + 1; - * } - * } - */ - - stmt = registerize(stmt); - - /* - * if (x<5 ? 1 : 0) { - * int A_1 = A[0]; - * for (int x = 0; x < 10; x++) { - * A_1 = A_1 + 1; - * } - * A[0] = A_1; - * } - * for (int x = 0; x < 10; x++) { - * A[x] = 1; - * } - * if (x>5 ? 
1 : 0) { - * int A_2 = A[0]; - * for (int x = 0; x < 10; x++) { - * A_2 = A_2 + 1; - * } - * A[0] = A_2; - * } - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: if (x<5 -# CHECK: int A_1 = A[0]; -# CHECK: for (int x -# CHECK: A_1 = A_1 + 1; -# CHECK: } -# CHECK: A[0] = A_1; -# CHECK: } -# CHECK: for (int x -# CHECK: A[x] = 1; -# CHECK: if (x>5 -# CHECK: int A_2 = A[0]; -# CHECK: for (int x -# CHECK: A_2 = A_2 + 1; -# CHECK: } -# CHECK: A[0] = A_2; -# CHECK: })IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// references a Let var in a local scope which cannot be hoisted out of the -// loop. -TEST(Registerizer, RegisterizerLoopLetVar) { - BufHandle a("A", {10}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - StmtPtr stmt = IRSimplifier::simplify(Block::make({For::make( - x, - 0, - 10, - Block::make( - {Let::make(y, 30), - Store::make(a, {y}, Add::make(x, Load::make(a, {y})))}))})); - - /* - * for (int x = 0; x < 10; x++) { - * int y = 30; - * A[y] = x + (A[y]); - * } - */ - - std::ostringstream before; - before << *stmt; - - // No change. - stmt = registerize(stmt); - - std::ostringstream after; - after << *stmt; - - ASSERT_EQ(before.str(), after.str()); -} - -// references a Let var in an outer scope that does not prevent hoisting the -// initializer. -TEST(Registerizer, RegisterizerLoopLetVarOuter) { - BufHandle a("A", {10}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - StmtPtr stmt = Block::make( - {Let::make(y, 30), - For::make( - x, - 0, - 10, - Block::make( - {Store::make(a, {y}, Add::make(x, Load::make(a, {y})))}))}); - - /* - * int y = 30; - * for (int x = 0; x < 10; x++) { - * A[y] = x + (A[y]); - * } - */ - - stmt = registerize(stmt); - - /* - * int y = 30; - * int A_1 = A[y]; - * for (int x = 0; x < 10; x++) { - * A_1 = A_1 + x; - * } - * A[y] = A_1; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int y = 30; -# CHECK: int A_1 = A[y]; -# CHECK: for (int x -# CHECK: A_1 = A_1 + x; -# CHECK: A[y] = A_1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Okay so the registerizer generally goes after index flattening, but just in -// case. Test multi index registerization. -TEST(Registerizer, RegisterizerMultiDim) { - BufHandle a("A", {3, 4, 5}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make( - {Store::make(a, {0, 1, 2}, 0), - For::make( - x, - 0, - 10, - Block::make({Store::make( - a, {0, 1, 2}, Add::make(Load::make(a, {0, 1, 2}), x))}))}); - - /* - * A[0, 1, 2] = 0; - * for (int x = 0; x < 10; x++) { - * A[0, 1, 2] = (A[0, 1, 2]) + x; - * } - */ - - stmt = registerize(stmt); - - /* - * int A_1 = 0; - * for (int x = 0; x < 10; x++) { - * A_1 = x + A_1; - * } - * A[0, 1, 2] = A_1; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: int A_1 = 0; -# CHECK: for (int x = 0; x < 10; x++) -# CHECK-NOT: A[ -# CHECK: A_1 = -# CHECK: A[0, 1, 2] = A_1;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// Won't registerize if only some dims match, but will still registerize -// distinct elements. 
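All of the multi-dimensional cases below reduce to one question: can two index tuples be proven never to name the same element? A simplified stand-in for that check, a hypothetical helper rather than the pass's real implementation, looks like this:

#include <cstddef>
#include <optional>
#include <vector>

// Two index tuples may alias only if they could match in every dimension; a
// single dimension holding provably distinct constants rules out any overlap.
// Indices that are not compile-time constants are modeled as std::nullopt.
bool mayOverlap(
    const std::vector<std::optional<int>>& a,
    const std::vector<std::optional<int>>& b) {
  for (std::size_t i = 0; i < a.size() && i < b.size(); ++i) {
    if (a[i].has_value() && b[i].has_value() && *a[i] != *b[i]) {
      return false; // provably distinct in this dimension: cannot overlap
    }
  }
  return true; // every dimension could coincide: must assume overlap
}

Under this rule {0, 1, 2} and {0, x, 2} may overlap because the middle index is unknown, while {0, x, 2} and {y, 2, 4} cannot because the last dimension holds distinct constants; that is exactly the split the MultiDimOverlap and MultiDimPartialOverlap tests verify.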
-TEST(Registerizer, RegisterizerMultiDimPartial) { - BufHandle a("A", {3, 4, 5}, kInt); - VarHandle x("x", kInt); - StmtPtr stmt = Block::make( - {Store::make(a, {0, 1, 2}, 0), - For::make( - x, - 0, - 10, - Block::make({Store::make( - a, {0, 2, 2}, Add::make(Load::make(a, {0, 1, 4}), x))}))}); - - /* - * A[0, 1, 2] = 0; - * for (int x = 0; x < 10; x++) { - * A[0, 2, 2] = (A[0, 1, 4]) + x; - * } - */ - - stmt = registerize(stmt); - - /* - * A[0, 1, 2] = 0; - * int A_1 = A[0, 1, 4]; - * int A_2 = A[0, 2, 2]; - * for (int x = 0; x < 10; x++) { - * A_2 = A_1 + x; - * } - * A[0, 2, 2] = A_2; - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: A[0, 1, 2] = 0; -# CHECK: int A_1 = A[0, 1, 4]; -# CHECK: int A_2 = A[0, 2, 2]; -# CHECK: for ( -# CHECK: A_2 = A_1 + x; -# CHECK: A[0, 2, 2] = A_2;)IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// If they could overlap across all dimensions we cannot registerize. -TEST(Registerizer, RegisterizerMultiDimOverlap) { - BufHandle a("A", {3, 4, 5}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - StmtPtr stmt = Block::make( - {Store::make(a, {0, 1, 2}, 0), - For::make( - x, - 0, - 10, - Block::make({Store::make( - a, {0, x, 2}, Add::make(Load::make(a, {y, 2, 2}), x))}))}); - stmt = IRSimplifier::simplify(stmt); - - /* - * A[0, 1, 2] = 0; - * for (int x = 0; x < 10; x++) { - * A[0, x, 2] = (A[y, 2, 2]) + x; - * } - */ - - std::ostringstream before; - before << *stmt; - - // No change. - stmt = registerize(stmt); - - std::ostringstream after; - after << *stmt; - - ASSERT_EQ(before.str(), after.str()); -} - -// But, if one dimension is known to be distinct they do not overlap. -TEST(Registerizer, RegisterizerMultiDimPartialOverlap) { - BufHandle a("A", {3, 4, 5}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - StmtPtr stmt = Block::make( - {Store::make(a, {0, 1, 2}, 0), - For::make( - x, - 0, - 10, - Block::make({Store::make( - a, {0, x, 2}, Add::make(Load::make(a, {y, 2, 4}), x))}))}); - - /* - * A[0, 1, 2] = 0; <---- 2nd dim overlaps with store. - * for (int x = 0; x < 10; x++) { - * A[0, x, 2] = (A[y, 2, 4]) + x; <---- 3rd dim has constant diff. - * } - */ - - stmt = registerize(stmt); - - /* - * A[0, 1, 2] = 0; - * int A_1 = A[y, 2, 4]; - * for (int x = 0; x < 10; x++) { - * A[0, x, 2] = A_1 + x; - * } - */ - - std::ostringstream oss; - oss << *stmt; - - const std::string& verification_pattern = - R"IR( -# CHECK: A[0, 1, 2] = 0; -# CHECK: int A_1 = A[y, 2, 4]; -# CHECK: for ( -# CHECK: A[0, x, 2] = A_1 + x; -# CHECK: })IR"; - - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -// A 3D reduction with different input dimensionality. -TEST(Registerizer, RegisterizerMultiDim3DReduction1) { - BufHandle a("A", {10}, kInt); - BufHandle b("B", {10, 10}, kInt); - BufHandle c("C", {10, 10, 10}, kInt); - VarHandle x("x", kInt); - VarHandle y("y", kInt); - VarHandle z("z", kInt); - StmtPtr stmt = For::make( - x, - 0, - 10, - For::make( - y, - 0, - 10, - For::make( - z, - 0, - 10, - Store::make( - c, - {x, y, z}, - Add::make( - Load::make(c, {x, y, z}), - Mul::make(Load::make(b, {x, y}), Load::make(a, {x}))))))); - - /* - * for (int x = 0; x < 10; x++) { - * for (int y = 0; y < 10; y++) { - * for (int z = 0; z < 10; z++) { - * C[x, y, z] = (C[x, y, z]) + (B[x, y]) * (A[x]); - * } - * } - * } - */ - - // We can registerize the A and B access since they can be hoisted before - // hitting a dependent loop var. 
-
-  stmt = registerize(stmt);
-
-  /*
-   * for (int x = 0; x < 10; x++) {
-   *   int A_1 = A[x];
-   *   for (int y = 0; y < 10; y++) {
-   *     int B_1 = B[x, y];
-   *     for (int z = 0; z < 10; z++) {
-   *       C[x, y, z] = A_1 * B_1 + (C[x, y, z]);
-   *     }
-   *   }
-   * }
-   */
-
-  std::ostringstream oss;
-  oss << *stmt;
-
-  const std::string& verification_pattern =
-      R"IR(
-# CHECK: for (int x
-# CHECK: int A_1 = A[x];
-# CHECK: for (int y
-# CHECK: int B_1 = B[x, y];
-# CHECK: for (int z
-# CHECK: C[x, y, z] = A_1 * B_1 + (C[x, y, z]);
-# CHECK: })IR";
-
-  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
-}
-
-// A 3D reduction with the same smaller dimensionality using different loop
-// vars.
-TEST(Registerizer, RegisterizerMultiDim3DReduction2) {
-  BufHandle a("A", {10}, kInt);
-  BufHandle b("B", {10}, kInt);
-  BufHandle c("C", {10}, kInt);
-  VarHandle x("x", kInt);
-  VarHandle y("y", kInt);
-  VarHandle z("z", kInt);
-  StmtPtr stmt = For::make(
-      x,
-      0,
-      10,
-      // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
-      For::make(
-          y,
-          0,
-          10,
-          For::make(
-              z,
-              0,
-              10,
-              Store::make(
-                  c,
-                  {x},
-                  Add::make(
-                      Load::make(c, {x}),
-                      Mul::make(Load::make(b, {y}), Load::make(a, {x})))))));
-
-  /*
-   * for (int x = 0; x < 10; x++) {
-   *   for (int y = 0; y < 10; y++) {
-   *     for (int z = 0; z < 10; z++) {
-   *       C[x] = (C[x]) + (B[y]) * (A[x]);
-   *     }
-   *   }
-   * }
-   */
-
-  // We can registerize all accesses: the A and C accesses can be hoisted to
-  // the outer loop, since they depend only on its loop var, while the B access
-  // can only be raised to the loop over y.
-
-  stmt = registerize(stmt);
-
-  /*
-   * for (int x = 0; x < 10; x++) {
-   *   int A_1 = A[x];
-   *   int C_1 = C[x];
-   *   for (int y = 0; y < 10; y++) {
-   *     int B_1 = B[y];
-   *     for (int z = 0; z < 10; z++) {
-   *       C_1 = A_1 * B_1 + C_1;
-   *     }
-   *   }
-   *   C[x] = C_1;
-   * }
-   */
-
-  std::ostringstream oss;
-  oss << *stmt;
-
-  const std::string& verification_pattern =
-      R"IR(
-# CHECK: for (int x
-# CHECK: int A_1 = A[x];
-# CHECK: int C_1 = C[x];
-# CHECK: for (int y
-# CHECK: int B_1 = B[y];
-# CHECK: for (int z
-# CHECK: C_1 = A_1 * B_1 + C_1;
-# CHECK: }
-# CHECK: }
-# CHECK: C[x] = C_1;
-# CHECK: })IR";
-
-  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
-}
-
-} // namespace jit
-} // namespace torch
diff --git a/test/cpp/tensorexpr/test_simplify.cpp b/test/cpp/tensorexpr/test_simplify.cpp
deleted file mode 100644
index 7ca2b74eaa76..000000000000
--- a/test/cpp/tensorexpr/test_simplify.cpp
+++ /dev/null
@@ -1,5680 +0,0 @@
-#include
-#include
-
-#include
-#include
-#include
-#include
-#include
-
-#include
-
-namespace torch {
-namespace jit {
-using namespace torch::jit::tensorexpr;
-using SimpleIRExprEval = ExprEval;
-
-TEST(Simplify, ConstantFoldSimple) {
-  ExprHandle a(2.0f);
-  ExprHandle b(3.0f);
-  ExprHandle f = (a + b);
-
-  ExprHandle newF = IRSimplifier::simplify(f);
-  ASSERT_NE(newF.AsNode(), nullptr);
-  ASSERT_EQ(newF.AsNode()->value(), 5);
-
-  SimpleIRExprEval eval(newF);
-  ASSERT_EQ(eval.value(), 5.f);
-}
-
-TEST(Simplify, ConstantFoldTwoLayer) {
-  ExprHandle a(2.0f);
-  ExprHandle b(3.0f);
-  ExprHandle c(4.0f);
-  ExprHandle d(5.0f);
-  ExprHandle f = (a + b) - (c + d);
-
-  ExprHandle newF = IRSimplifier::simplify(f);
-  ASSERT_NE(newF.AsNode(), nullptr);
-  ASSERT_EQ(newF.AsNode()->value(), -4);
-
-  SimpleIRExprEval eval(newF);
-  ASSERT_EQ(eval.value(), -4.f);
-}
-
-TEST(Simplify, ConstantFoldShifts) {
-  ExprHandle a(7);
-  ExprHandle b(2);
-  ExprHandle c(3);
-  ExprHandle f = ((a << b) << b) >> c;
-
-  ExprHandle newF =
IRSimplifier::simplify(f); - ASSERT_NE(newF.AsNode(), nullptr); - ASSERT_EQ(newF.AsNode()->value(), 14); - - SimpleIRExprEval eval(newF); - ASSERT_EQ(eval.value(), 7 << (4 - 3)); -} - -TEST(Simplify, ConstantFoldBitwise) { - ExprHandle a(59); - ExprHandle b(22); - ExprHandle c(101); - ExprHandle f = (a ^ b) & c; - - ExprHandle newF = IRSimplifier::simplify(f); - ASSERT_NE(newF.AsNode(), nullptr); - ASSERT_EQ(newF.AsNode()->value(), 37); - - SimpleIRExprEval eval(newF); - ASSERT_EQ(eval.value(), (59 ^ 22) & 101); -} - -TEST(Simplify, ConstantFoldMultiOp) { - ExprHandle a(2.0f); - ExprHandle b(3.0f); - ExprHandle c(4.0f); - ExprHandle d(5.0f); - ExprHandle e(6.0f); - ExprHandle f(7.0f); - ExprHandle fn = ((a / e) - (c + d)) * (f / b); - - ExprHandle newF = IRSimplifier::simplify(fn); - ASSERT_NE(newF.AsNode(), nullptr); - - SimpleIRExprEval eval(newF); - SimpleIRExprEval ref(fn); - - ASSERT_EQ(eval.value(), ref.value()); -} - -TEST(Simplify, ConstantFoldMinMax) { - ExprHandle a(12.0f); - ExprHandle b(15.0f); - ExprHandle c(17.0f); - - // x = max(12, min(15, 17)). - ExprHandle minHandle = Min::make(b, c, true); - ExprHandle fn = Max::make(a, minHandle, false); - - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - ASSERT_EQ(fn.dtype().scalar_type(), ScalarType::Float); - - ExprHandle newF = IRSimplifier::simplify(fn); - ASSERT_NE(newF.AsNode(), nullptr); - - SimpleIRExprEval eval(newF); - ASSERT_EQ(eval.value(), 15.f); -} - -TEST(Simplify, ConstantFoldIntrinsics) { - ExprHandle a(2.0f); - ExprHandle b(3.0f); - ExprHandle c(4.0f); - ExprHandle powHandle = Intrinsics::make(kPow, a, b); - ExprHandle sinHandle = Intrinsics::make(kSin, powHandle); - ExprHandle modHandle = Intrinsics::make(kFmod, c, sinHandle); - ExprHandle logHandle = Intrinsics::make(kLog10, modHandle); - ExprHandle rndHandle = Intrinsics::make(kRound, logHandle); - ExprHandle fn = Intrinsics::make(kAbs, rndHandle); - - ExprHandle newF = IRSimplifier::simplify(fn); - ASSERT_NE(newF.AsNode(), nullptr); - ASSERT_EQ(newF.AsNode()->value(), 1); - - SimpleIRExprEval eval(newF); - SimpleIRExprEval ref(fn); - - ASSERT_EQ(eval.value(), ref.value()); -} - -TEST(Simplify, ConstantFoldCastToBool) { - ExprHandle f = Cast::make(kBool, IntImm::make(0)); - ExprHandle newF = IRSimplifier::simplify(f); - SimpleIRExprEval eval(newF); - ASSERT_EQ(eval.value(), false); -} - -TEST(Simplify, ConstantFoldWithVar) { - { - VarHandle x("x", kInt); - ExprHandle body = x * (ExprHandle(2) + ExprHandle(4)); - - ExprHandle newF = IRSimplifier::simplify(body); - MulPtr root = newF.AsNode(); - ASSERT_NE(root, nullptr); - ASSERT_NE(to(root->lhs()), nullptr); - - SimpleIRExprEval eval(newF); - eval.bindVar(x, ExprHandle(3)); - ASSERT_EQ(eval.value(), 3 * (2 + 4)); - } - - { - VarHandle x("x", kFloat); - ExprHandle body = x * (ExprHandle(2.f) + ExprHandle(4.f)); - - ExprHandle newF = IRSimplifier::simplify(body); - MulPtr root = newF.AsNode(); - ASSERT_NE(root, nullptr); - ASSERT_NE(to(root->rhs()), nullptr); - - SimpleIRExprEval eval(newF); - eval.bindVar(x, ExprHandle(3.f)); - ASSERT_EQ(eval.value(), 3 * (2 + 4)); - } -} - -TEST(Simplify, ConditionalSelectFoldSimple) { - ExprHandle a(3.0f); - ExprHandle b(4.0f); - ExprHandle c(3.0f); - { - ExprHandle f = (a > b); - - ExprHandle newF = IRSimplifier::simplify(f); - ASSERT_NE(newF.AsNode(), nullptr); - ASSERT_EQ(newF.AsNode()->value(), 0); - - SimpleIRExprEval eval(newF); - ASSERT_EQ(eval.value(), 0); - } - { - ExprHandle f = (a < b); - - ExprHandle newF = IRSimplifier::simplify(f); - 
ASSERT_NE(newF.AsNode(), nullptr); - ASSERT_EQ(newF.AsNode()->value(), 1); - - SimpleIRExprEval eval(newF); - ASSERT_EQ(eval.value(), 1); - } - { - ExprHandle f = (a == c); - - ExprHandle newF = IRSimplifier::simplify(f); - ASSERT_NE(newF.AsNode(), nullptr); - ASSERT_EQ(newF.AsNode()->value(), 1); - - SimpleIRExprEval eval(newF); - ASSERT_EQ(eval.value(), 1); - } - { - ExprHandle f = (a != c); - - ExprHandle newF = IRSimplifier::simplify(f); - ASSERT_NE(newF.AsNode(), nullptr); - ASSERT_EQ(newF.AsNode()->value(), 0); - - SimpleIRExprEval eval(newF); - ASSERT_EQ(eval.value(), 0); - } -} - -TEST(Simplify, ConditionalSelectFoldTwoLayer) { - ExprHandle a(3.0f); - ExprHandle b(2.0f); - ExprHandle c(2.0f); - ExprHandle d(1.0f); - { - ExprHandle f = (a + b < c + d); - - ExprHandle newF = IRSimplifier::simplify(f); - ASSERT_NE(newF.AsNode(), nullptr); - ASSERT_EQ(newF.AsNode()->value(), 0); - - SimpleIRExprEval eval(newF); - ASSERT_EQ(eval.value(), 0); - } - { - ExprHandle f = (a + b > c + d); - - ExprHandle newF = IRSimplifier::simplify(f); - ASSERT_NE(newF.AsNode(), nullptr); - ASSERT_EQ(newF.AsNode()->value(), 1); - - SimpleIRExprEval eval(newF); - ASSERT_EQ(eval.value(), 1); - } - { - ExprHandle f = (a + d == b + c); - - ExprHandle newF = IRSimplifier::simplify(f); - ASSERT_NE(newF.AsNode(), nullptr); - ASSERT_EQ(newF.AsNode()->value(), 1); - - SimpleIRExprEval eval(newF); - ASSERT_EQ(eval.value(), 1); - } - { - ExprHandle f = (a + d != b + c); - - ExprHandle newF = IRSimplifier::simplify(f); - ASSERT_NE(newF.AsNode(), nullptr); - ASSERT_EQ(newF.AsNode()->value(), 0); - - SimpleIRExprEval eval(newF); - ASSERT_EQ(eval.value(), 0); - } -} - -TEST(Simplify, ConditionalSelectFoldWithVar) { - VarHandle x("x", kFloat); - ExprHandle f = x < 4.f; - - ExprHandle newF = IRSimplifier::simplify(f); - IntImmPtr folded = newF.AsNode(); - ASSERT_EQ(folded, nullptr); - - { - SimpleIRExprEval eval(newF); - eval.bindVar(x, ExprHandle(3.f)); - ASSERT_EQ(eval.value(), 1); - } - { - SimpleIRExprEval eval(newF); - eval.bindVar(x, ExprHandle(5.f)); - ASSERT_EQ(eval.value(), 0); - } -} - -TEST(Simplify, UnFoldableExpr) { - VarHandle x("x", kFloat); - VarHandle y("y", kFloat); - ExprHandle body = (ExprHandle(3) * x) + (ExprHandle(5) * y); - - ExprHandle newF = IRSimplifier::simplify(body); - AddPtr root = newF.AsNode(); - ASSERT_NE(root, nullptr); - ASSERT_EQ(to(root->lhs()), nullptr); - ASSERT_EQ(to(root->rhs()), nullptr); - - SimpleIRExprEval eval(newF); - eval.bindVar(x, ExprHandle(3.f)); - eval.bindVar(y, ExprHandle(2.f)); - ASSERT_EQ(eval.value(), 9 + 10); -} - -TEST(Simplify, HashSimple) { - VarHandle x("x", kFloat); - ExprHandle a(2.0f); - ExprHandle b(3.0f); - ExprHandle f = a + b * x; - - HashProvider hasher; - - auto hash_x = hasher.hash(x.node()); - auto hash_a = hasher.hash(a.node()); - auto hash_f = hasher.hash(f.node()); - - ASSERT_NE(hash_x, (size_t)0); - ASSERT_NE(hash_a, (size_t)0); - ASSERT_NE(hash_f, (size_t)0); - ASSERT_NE(hash_x, hash_a); - ASSERT_NE(hash_x, hash_f); - ASSERT_NE(hash_a, hash_f); -} - -TEST(Simplify, HashEquivalence) { - VarHandle x("x", kFloat); - VarHandle y("y", kFloat); - ExprHandle f = (x * y) + (x * y); - - AddPtr root = f.AsNode(); - ASSERT_NE(root, nullptr); - - HashProvider hasher; - auto hash_f = hasher.hash(f.node()); - auto hash_l = hasher.hash(root->lhs()); - auto hash_r = hasher.hash(root->rhs()); - - // Root not equal to either branch. - ASSERT_NE(hash_f, hash_l); - ASSERT_NE(hash_f, hash_r); - // but branches are equal. 
- ASSERT_EQ(hash_l, hash_r); - - // Still equivalent if separate. - ExprHandle a(2); - ExprHandle f2 = x + a / y; - ExprHandle b(2); - ExprHandle f3 = x + b / y; - ASSERT_EQ(hasher.hash(f2.node()), hasher.hash(f3.node())); - - // Not equivalent if different vars (even with same name). - VarHandle z("x", kFloat); - ExprHandle f4 = z + b / y; - ASSERT_NE(hasher.hash(f2.node()), hasher.hash(f4.node())); - - // Intrinsics sanity check. - ExprHandle f5 = Intrinsics::make(kSin, x) * Intrinsics::make(kCos, x); - ASSERT_NE(hasher.hash(f5.node()), (size_t)0); -} - -TEST(Simplify, HashEquivalenceRand) { - ExprHandle f = - Intrinsics::make(kRand, kFloat) + Intrinsics::make(kRand, kInt); - - AddPtr root = f.AsNode(); - ASSERT_NE(root, nullptr); - - HashProvider hasher; - auto hash_f = hasher.hash(f.node()); - auto hash_l = hasher.hash(root->lhs()); - auto hash_r = hasher.hash(root->rhs()); - - // Root not equal to either branch. - ASSERT_NE(hash_f, hash_l); - ASSERT_NE(hash_f, hash_r); - // and branches are NOT equal. - ASSERT_NE(hash_l, hash_r); -} - -TEST(Simplify, HashEquivalenceAfterFolding) { - VarHandle x("x", kFloat); - ExprHandle a(2.0f); - ExprHandle b(3.0f); - ExprHandle c(5.0f); - - ExprHandle f1 = ((a + b) * x); - ExprHandle f2 = (c * x); - - HashProvider hasher; - auto hash_l = hasher.hash(f1.node()); - auto hash_r = hasher.hash(f2.node()); - - // Root not equal to either branch, and branches not equal. - ASSERT_NE(hash_l, hash_r); - - ExprHandle ff1 = IRSimplifier::simplify(f1); - ExprHandle ff2 = IRSimplifier::simplify(f2); - - auto hash_l_n = hasher.hash(ff1.node()); - auto hash_r_n = hasher.hash(ff2.node()); - // but branches are now equal. - ASSERT_EQ(hash_l_n, hash_r_n); -} - -TEST(Simplify, HashDifferenceTypes) { - HashProvider hasher; - std::vector immediates; - - immediates.push_back(alloc(1)); - immediates.push_back(alloc(1)); - immediates.push_back(alloc(1)); - // NOLINTNEXTLINE(modernize-use-bool-literals) - immediates.push_back(alloc(1)); - immediates.push_back(alloc(1)); - immediates.push_back(alloc(1)); - immediates.push_back(alloc(1)); - immediates.push_back(alloc(1)); - immediates.push_back(alloc(1)); - - // Immediates of different types are not equal. - for (unsigned int i = 0; i < immediates.size(); ++i) { - for (unsigned int j = i + 1; j < immediates.size(); ++j) { - ASSERT_NE(hasher.hash(immediates[i]), hasher.hash(immediates[j])); - } - } - - // But coerced immediates are if they are the same type: - ExprHandle f1 = ExprHandle(2.f) + CharImm::make(1); - ExprHandle f2 = Cast::make(kFloat, IntImm::make(3)); - - ExprHandle ff1 = IRSimplifier::simplify(f1); - ExprHandle ff2 = IRSimplifier::simplify(f2); - - ASSERT_EQ(hasher.hash(ff1.node()), hasher.hash(ff2.node())); -} - -TEST(Simplify, HashLargeExpression) { - constexpr int N = 1024; - BufHandle a("A", {N}, kInt); - BufHandle b("B", {N}, kInt); - BufHandle c("C", {N}, kInt); - VarHandle i("i", kInt); - auto memcpy_stmt = For::make( - i, - 0, - N, - Store::make( - c, - {i}, - CompareSelect::make( - Load::make(a, {i}), - Load::make(b, {i}), - CompareSelectOperation::kEQ))); - - BufHandle d("D", {1}, kInt); - BufHandle e("E", {1}, kInt); - auto store_ramp_stmt = Store::make( - e, {Ramp::make(0, 1, 4)}, Load::make(d, {Ramp::make(0, 1, 4)})); - - auto if_stmt = Cond::make( - CompareSelect::make( - Load::make(a, {i}), Load::make(b, {i}), CompareSelectOperation::kGE), - memcpy_stmt, - store_ramp_stmt); - - HashProvider hasher; - auto hash_r = hasher.hash(if_stmt); - // We should not have to do any more work. 
- ASSERT_TRUE(hasher.cachedHash(memcpy_stmt)); - auto hash_t = hasher.hash(memcpy_stmt); - ASSERT_TRUE(hasher.cachedHash(store_ramp_stmt)); - auto hash_f = hasher.hash(store_ramp_stmt); - - // Root not equal to either branch, and branches not equal. - ASSERT_NE(hash_r, hash_t); - ASSERT_NE(hash_r, hash_f); - ASSERT_NE(hash_t, hash_f); -} - -TEST(Simplify, HashForLoopOptions) { - constexpr int N = 1024; - BufHandle a("A", {N}, kInt); - BufHandle b("B", {N}, kInt); - BufHandle c("C", {N}, kInt); - VarHandle i("i", kInt); - auto for_stmt = For::make( - i, - 0, - N, - Store::make( - c, - {i}, - CompareSelect::make( - Load::make(a, {i}), - Load::make(b, {i}), - CompareSelectOperation::kEQ))); - - HashProvider hasher; - auto hash_before = hasher.hash(for_stmt); - hasher.clearCache(); - - for_stmt->set_gpu_block_index(LoopOptions::IDX_X); - auto hash_block_idx = hasher.hash(for_stmt); - hasher.clearCache(); - - ASSERT_NE(hash_before, hash_block_idx); - - for_stmt->set_gpu_block_index(LoopOptions::IDX_UNSET); - auto hash_reset = hasher.hash(for_stmt); - hasher.clearCache(); - - ASSERT_EQ(hash_before, hash_reset); - for_stmt->set_gpu_thread_index(LoopOptions::IDX_X); - auto hash_thread_idx = hasher.hash(for_stmt); - - ASSERT_NE(hash_before, hash_thread_idx); - ASSERT_NE(hash_block_idx, hash_thread_idx); -} - -/// (2 + x) + 4 => x + 6 -TEST(Simplify, SimplifyAdd) { - VarHandle x("x", kInt); - VarHandle y("y", kInt); - - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - VarHandle m("m", kInt); - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - VarHandle n("n", kInt); - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - VarHandle n_1("n_1", kInt); - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - ExprHandle body = (ExprHandle(2) + x) + ExprHandle(4); - - ExprHandle simplified = IRSimplifier::simplify(body); - AddPtr root = simplified.AsNode(); - ASSERT_NE(root, nullptr); - VarPtr lhs = to(root->lhs()); - ASSERT_NE(lhs, nullptr); - ASSERT_EQ(lhs->name_hint(), "x"); - IntImmPtr rhs = to(root->rhs()); - ASSERT_NE(rhs, nullptr); - ASSERT_EQ(rhs->value(), 6.f); -} - -/// (2 - x) - 4 => -2 - x -TEST(Simplify, SimplifySub) { - VarHandle x("x", kInt); - ExprHandle body = (ExprHandle(2) - x) - ExprHandle(4); - - ExprHandle simplified = IRSimplifier::simplify(body); - SubPtr root = simplified.AsNode(); - ASSERT_NE(root, nullptr); - IntImmPtr lhs = to(root->lhs()); - ASSERT_NE(lhs, nullptr); - ASSERT_EQ(lhs->value(), -2.f); - VarPtr rhs = to(root->rhs()); - ASSERT_NE(rhs, nullptr); - ASSERT_EQ(rhs->name_hint(), "x"); -} - -/// 2 * (1 - x) - 4 => 2 * (-3 - x) -TEST(Simplify, SimplifyMultiLayer) { - VarHandle x("x", kInt); - ExprHandle body = ExprHandle(2) * ((ExprHandle(1) - x) - ExprHandle(4)); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), 2); - IS_NODE_WITH_NAME(Sub, mul->rhs(), sub); - IS_IMM_WITH_VAL(Int, sub->lhs(), -3); - IS_VAR_WITH_NAME(sub->rhs(), "x"); -} - -/// 2 * (3 * x) - (x * 4) => 2 * x -TEST(Simplify, SimplifyMultiTerm) { - VarHandle x("x", kInt); - ExprHandle body = - (ExprHandle(2) * ((ExprHandle(3) * x)) - (x * ExprHandle(4))); - - ExprHandle simplified = IRSimplifier::simplify(body); - MulPtr root = simplified.AsNode(); - ASSERT_NE(root, nullptr); - IntImmPtr lhs = to(root->lhs()); - ASSERT_NE(lhs, nullptr); - ASSERT_EQ(lhs->value(), 2); - VarPtr rhs = to(root->rhs()); - ASSERT_NE(rhs, nullptr); - ASSERT_EQ(rhs->name_hint(), "x"); -} - -/// 2 * (3 * 
(long)x) - (x * 4) => 2 * x -TEST(Simplify, SimplifyCasts) { - VarHandle x("x", kLong); - ExprHandle body = - (ExprHandle(2) * ((ExprHandle(3) * x)) - (x * ExprHandle(4))); - - ExprHandle simplified = IRSimplifier::simplify(body); - MulPtr root = simplified.AsNode(); - ASSERT_NE(root, nullptr); - LongImmPtr lhs = to(root->lhs()); - ASSERT_NE(lhs, nullptr); - ASSERT_EQ(lhs->value(), 2); - VarPtr rhs = to(root->rhs()); - ASSERT_NE(rhs, nullptr); - ASSERT_EQ(rhs->name_hint(), "x"); -} - -/// (x + 0) * 1 => x -TEST(Simplify, SimplifyEliminatesNoOps) { - VarHandle x("x", kInt); - ExprHandle body = (x + ExprHandle(0)) * 1; - - ExprHandle simplified = IRSimplifier::simplify(body); - VarPtr root = simplified.AsNode(); - ASSERT_NE(root, nullptr); - ASSERT_EQ(root->name_hint(), "x"); -} - -/// Cannot simplify this. -TEST(Simplify, SimplifyMultiVar) { - VarHandle x("x", kInt); - VarHandle y("y", kInt); - ExprHandle body = x * 24 + y * 34; - - ExprHandle simplified = IRSimplifier::simplify(body); - - AddPtr root = simplified.AsNode(); - ASSERT_NE(root, nullptr); - MulPtr lhs = to(root->lhs()); - ASSERT_NE(lhs, nullptr); - VarPtr varX = to(lhs->rhs()); - ASSERT_NE(varX, nullptr); - ASSERT_EQ(varX->name_hint(), "x"); - MulPtr rhs = to(root->rhs()); - ASSERT_NE(rhs, nullptr); - VarPtr varY = to(rhs->rhs()); - ASSERT_NE(varY, nullptr); - ASSERT_EQ(varY->name_hint(), "y"); -} - -// x + 2 + y => x + y + 2 -TEST(Simplify, DISABLED_SimplifyReorderings) { - VarHandle x("x", kInt); - VarHandle y("y", kInt); - ExprHandle body = x + 2 + y; - ExprHandle simplified = IRSimplifier::simplify(body); - - AddPtr root = simplified.AsNode(); - ASSERT_NE(root, nullptr); - - IS_NODE_WITH_NAME(Add, root->lhs(), rhs); - IS_VAR_WITH_NAME(rhs->lhs(), "x"); - IS_VAR_WITH_NAME(rhs->rhs(), "y"); - IS_IMM_WITH_VAL(Int, root->rhs(), 2); -} - -/// y + x * 0 => y -TEST(Simplify, SimplifyEliminatesVar) { - VarHandle x("x", kInt); - VarHandle y("y", kInt); - ExprHandle body = y + x * ExprHandle(0); - - ExprHandle simplified = IRSimplifier::simplify(body); - IS_VAR_WITH_NAME(simplified.node(), "y"); -} - -TEST(Simplify, SimplifyAdds) { - VarHandle x("x", kInt); - VarHandle y("y", kInt); - - { - // (x + y) + (x + y) => 2 * (x + y) - ExprHandle body = (x + y) + (x + y); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), root); - IS_IMM_WITH_VAL(Int, root->lhs(), 2); - IS_NODE_WITH_NAME(Add, root->rhs(), add); - IS_VAR_WITH_NAME(add->lhs(), "x"); - IS_VAR_WITH_NAME(add->rhs(), "y"); - } - - { - // (x * y) + (x * y) => 2 * (x * y) - ExprHandle body = (x * y) + (x * y); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), root); - IS_IMM_WITH_VAL(Int, root->lhs(), 2); - IS_NODE_WITH_NAME(Mul, root->rhs(), mul); - IS_VAR_WITH_NAME(mul->lhs(), "x"); - IS_VAR_WITH_NAME(mul->rhs(), "y"); - } - - { - // (x - y) + (x - y) => 2 * (x - y) - ExprHandle body = (x - y) + (x - y); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), 2); - - IS_NODE_WITH_NAME(Sub, mul->rhs(), rhs); - IS_VAR_WITH_NAME(rhs->lhs(), "x"); - IS_VAR_WITH_NAME(rhs->rhs(), "y"); - } - - { - // (x + x + x + x) => 4 * x - ExprHandle body = (x + x + x + x); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), root); - IS_IMM_WITH_VAL(Int, root->lhs(), 4); - IS_VAR_WITH_NAME(root->rhs(), "x"); - } - - { - // (x + 0) => x. 
- ExprHandle body = x + 0; - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_VAR_WITH_NAME(simplified.node(), "x"); - } - - { - // (x + 0.f) => float(x). - ExprHandle body = x + 0.f; - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Cast, simplified.node(), cast); - ASSERT_EQ(cast->dtype().scalar_type(), ScalarType::Float); - IS_VAR_WITH_NAME(cast->src_value(), "x"); - } -} - -TEST(Simplify, SimplifyMuls) { - VarHandle x("x", kInt); - VarHandle y("y", kInt); - - { - // (x + y) * (x + y) => (x + y) * (x + y) - // We don't attempt to simplify multiplication of polynomials since the - // result is only very rarely more efficient. - ExprHandle body = (x + y) * (x + y); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_NODE_WITH_NAME(Add, mul->lhs(), lhs); - IS_VAR_WITH_NAME(lhs->lhs(), "x"); - IS_VAR_WITH_NAME(lhs->rhs(), "y"); - IS_NODE_WITH_NAME(Add, mul->rhs(), rhs); - IS_VAR_WITH_NAME(rhs->lhs(), "x"); - IS_VAR_WITH_NAME(rhs->rhs(), "y"); - } - - { - // x * y * x * y => x * x * y * y - // These get reordered only. - ExprHandle body = x * y * x * y; - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), mul1); - IS_NODE_WITH_NAME(Mul, mul1->lhs(), mul2); - IS_NODE_WITH_NAME(Mul, mul2->lhs(), mul3); - IS_VAR_WITH_NAME(mul1->rhs(), "y"); - IS_VAR_WITH_NAME(mul2->rhs(), "y"); - IS_VAR_WITH_NAME(mul3->lhs(), "x"); - IS_VAR_WITH_NAME(mul3->rhs(), "x"); - } - - { - // 1 * (x * 1) => x - // Ones cancel cleanly. - ExprHandle body = ExprHandle(1) * (x * ExprHandle(1)); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_VAR_WITH_NAME(simplified.node(), "x"); - } - - { - // 1.f * (x * 1.f) => x - // Even float ones cancel cleanly, but carry their type. - ExprHandle body = ExprHandle(1.f) * (x * ExprHandle(1.f)); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Cast, simplified.node(), cast); - ASSERT_EQ(cast->dtype().scalar_type(), ScalarType::Float); - IS_VAR_WITH_NAME(cast->src_value(), "x"); - } - - { - // 1 * (x * 1.f) => x - // One float is enough to cast the expr. - ExprHandle body = ExprHandle(1) * (x * ExprHandle(1.f)); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Cast, simplified.node(), cast); - ASSERT_EQ(cast->dtype().scalar_type(), ScalarType::Float); - IS_VAR_WITH_NAME(cast->src_value(), "x"); - } - - { - // 1 * (x * 0) => 0 - // Zeroes are eliminated. - ExprHandle body = ExprHandle(1) * (x * ExprHandle(0)); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_IMM_WITH_VAL(Int, simplified.node(), 0); - } - - { - // 1 * (x * 0) => 0 - // But not for Float since nan * 0 = nan. - ExprHandle body = ExprHandle(1.f) * (x * ExprHandle(0.f)); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_NODE_WITH_NAME(Cast, mul->lhs(), cast); - ASSERT_EQ(cast->dtype().scalar_type(), ScalarType::Float); - IS_VAR_WITH_NAME(cast->src_value(), "x"); - IS_IMM_WITH_VAL(Float, mul->rhs(), 0.0); - } - - { - // (x - y) * (x - y) => (x - y) * (x - y) - // As with Add we don't attempt simplification of this. 
- ExprHandle body = (x - y) * (x - y); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_NODE_WITH_NAME(Sub, mul->lhs(), lhs); - IS_VAR_WITH_NAME(lhs->lhs(), "x"); - IS_VAR_WITH_NAME(lhs->rhs(), "y"); - IS_NODE_WITH_NAME(Sub, mul->rhs(), rhs); - IS_VAR_WITH_NAME(rhs->lhs(), "x"); - IS_VAR_WITH_NAME(rhs->rhs(), "y"); - } - - { - // (x + y) * (x - y) => (x + y) * (x - y) - // Don't simplify with different ops on each side. - ExprHandle body = (x + y) * (x - y); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_NODE_WITH_NAME(Add, mul->lhs(), lhs); - IS_VAR_WITH_NAME(lhs->lhs(), "x"); - IS_VAR_WITH_NAME(lhs->rhs(), "y"); - IS_NODE_WITH_NAME(Sub, mul->rhs(), rhs); - IS_VAR_WITH_NAME(rhs->lhs(), "x"); - IS_VAR_WITH_NAME(rhs->rhs(), "y"); - } - - { - // Multiply a polynomial by a term. - // - term with no scalar, poly with non-identity scalar. - // x * (y + 1) => x + x * y - ExprHandle body = x * (y + ExprHandle(1)); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Add, simplified.node(), add); - IS_VAR_WITH_NAME(add->lhs(), "x"); - IS_NODE_WITH_NAME(Mul, add->rhs(), mul); - IS_VAR_WITH_NAME(mul->lhs(), "x"); - IS_VAR_WITH_NAME(mul->rhs(), "y"); - } - - { - // Multiply a polynomial by a term. - // - term with identity scalar, poly with non-identity scalar. - // (x * 1) * (y + 1) => x + x * y - ExprHandle body = (x * ExprHandle(1)) * (y + ExprHandle(1)); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Add, simplified.node(), add); - IS_VAR_WITH_NAME(add->lhs(), "x"); - IS_NODE_WITH_NAME(Mul, add->rhs(), mul); - IS_VAR_WITH_NAME(mul->lhs(), "x"); - IS_VAR_WITH_NAME(mul->rhs(), "y"); - } - - { - // Multiply a polynomial by a term. - // - term with non-identity scalar, poly with non-identity scalar. - // (x * 2) * (y + 1) => 2 * (x + x * y) - ExprHandle body = (x * ExprHandle(2)) * (y + ExprHandle(1)); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), 2); - IS_NODE_WITH_NAME(Add, mul->rhs(), add); - IS_VAR_WITH_NAME(add->lhs(), "x"); - IS_NODE_WITH_NAME(Mul, add->rhs(), mul2); - IS_VAR_WITH_NAME(mul2->lhs(), "x"); - IS_VAR_WITH_NAME(mul2->rhs(), "y"); - } - - { - // Multiply a polynomial by a term. - // - term with non-identity scalar, poly with identity scalar. - // (x * 2) * (y + 0) => 2 * (x * y) - ExprHandle body = (x * ExprHandle(2)) * (y + ExprHandle(0)); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), 2); - IS_NODE_WITH_NAME(Mul, mul->rhs(), mul2); - IS_VAR_WITH_NAME(mul2->lhs(), "x"); - IS_VAR_WITH_NAME(mul2->rhs(), "y"); - } - - { - // Multiply a polynomial by a term. - // - term with identity scalar, poly with identity scalar. - // (x * 1) * (y + 0) => x * y - ExprHandle body = (x * ExprHandle(1)) * (y + ExprHandle(0)); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_VAR_WITH_NAME(mul->lhs(), "x"); - IS_VAR_WITH_NAME(mul->rhs(), "y"); - } - - { - // Multiply a polynomial by a term. - // - term with no scalar, poly with identity scalar. 
- // x * (y + 0) => x * y - ExprHandle body = x * (y + ExprHandle(0)); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_VAR_WITH_NAME(mul->lhs(), "x"); - IS_VAR_WITH_NAME(mul->rhs(), "y"); - } -} - -// Sub an expr from itself will result in zero. -TEST(Simplify, SimplifySubs) { - VarHandle x("x", kInt); - VarHandle y("y", kInt); - - { - // (x + y) - (x + y) => 0 - ExprHandle body = (x + y) - (x + y); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_IMM_WITH_VAL(Int, simplified.node(), 0); - } - - { - // (x * y) - (x * y) => 0 - ExprHandle body = (x * y) - (x * y); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_IMM_WITH_VAL(Int, simplified.node(), 0); - } - - { - // (x - y) - (x - y) => 0 - ExprHandle body = (x - y) - (x - y); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_IMM_WITH_VAL(Int, simplified.node(), 0); - } - - { - // (x + y) - 2 * (x + y) => -1 * x - y - ExprHandle body = (x + y) - ExprHandle(2) * (x + y); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Sub, simplified.node(), sub); - IS_NODE_WITH_NAME(Mul, sub->lhs(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), -1); - IS_VAR_WITH_NAME(mul->rhs(), "x"); - IS_VAR_WITH_NAME(sub->rhs(), "y"); - } - - { - // (x + y) - y => x - ExprHandle body = (x + y) - y; - ExprHandle simplified = IRSimplifier::simplify(body); - IS_VAR_WITH_NAME(simplified.node(), "x"); - } - - { - // (x - 0) => x. - ExprHandle body = x - 0; - ExprHandle simplified = IRSimplifier::simplify(body); - IS_VAR_WITH_NAME(simplified.node(), "x"); - } - - { - // (x - 0.f) => x. - // Simple enough to cancel in float. - ExprHandle body = x - ExprHandle(0.f); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Cast, simplified.node(), cast); - ASSERT_EQ(cast->dtype().scalar_type(), ScalarType::Float); - IS_VAR_WITH_NAME(cast->src_value(), "x"); - } - - { - // (x - (float)(y - y)) => x. - ExprHandle body = x - Cast::make(kFloat, y - y); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Cast, simplified.node(), cast); - ASSERT_EQ(cast->dtype().scalar_type(), ScalarType::Float); - IS_VAR_WITH_NAME(cast->src_value(), "x"); - } - - { - // (x - y) - y => x - 2 * y - ExprHandle body = (x - y) - y; - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Sub, simplified.node(), sub); - IS_VAR_WITH_NAME(sub->lhs(), "x"); - IS_NODE_WITH_NAME(Mul, sub->rhs(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), 2); - IS_VAR_WITH_NAME(mul->rhs(), "y"); - } - - { - // 2 * x - x => x - ExprHandle body = (ExprHandle(2) * x) - x; - ExprHandle simplified = IRSimplifier::simplify(body); - IS_VAR_WITH_NAME(simplified.node(), "x"); - } - - { - // x - 2 * x = -1 * x - // We don't have a unary negate, but this could be 0 -x I guess? - ExprHandle body = x - (ExprHandle(2) * x); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - - IS_IMM_WITH_VAL(Int, mul->lhs(), -1); - IS_VAR_WITH_NAME(mul->rhs(), "x"); - } - - { - // (x + y + 5) * (x - x) => 0 - // Cancelling out one side of Mul cancels both. - ExprHandle body = (x + y + 5) * (x - x); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_IMM_WITH_VAL(Int, simplified.node(), 0); - } - - { - // Cancel out opaque modulus. 
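"Opaque" here means the simplifier treats x % y as a single indivisible unit; identical units still cancel like any other term. The two cancellations below can be spot-checked with plain integer arithmetic, independent of the simplifier (a minimal standalone sketch; the sample range is arbitrary):

#include <cassert>

int main() {
  for (int x = -20; x <= 20; ++x) {
    for (int y = -20; y <= 20; ++y) {
      if (y == 0) continue; // x % 0 is undefined
      // (x % y + 2) - (x % y) => 2
      assert((x % y + 2) - (x % y) == 2);
      // (x % y + (x * 2 - x - y * 0) - x + 2) - (x % y) => 2
      assert((x % y + (x * 2 - x - y * 0) - x + 2) - (x % y) == 2);
    }
  }
  return 0;
}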
- ExprHandle body = (x % y + 2) - (x % y); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_IMM_WITH_VAL(Int, simplified.node(), 2); - } - - { - // Cancel out opaque modulus with a bit more going on. - ExprHandle body = (x % y + (x * 2 - x - y * 0) - x + 2) - (x % y); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_IMM_WITH_VAL(Int, simplified.node(), 2); - } - - { - // Sub where result is negative. - ExprHandle body = x - (x + 1); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_IMM_WITH_VAL(Int, simplified.node(), -1); - } - - { - // Sub where result is positive due to negative scalar on RHS. - ExprHandle body = x - (x - 1); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_IMM_WITH_VAL(Int, simplified.node(), 1); - } - - { - // Term - Polynomial sub where RHS must be negated. - ExprHandle body = (x * 2) - (x * 2 + 1); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_IMM_WITH_VAL(Int, simplified.node(), -1); - } - - { - // Term - Polynomial sub where the result is a Term. - ExprHandle body = (y * x * 2) - (x * y); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - - IS_VAR_WITH_NAME(mul->lhs(), "x"); - IS_VAR_WITH_NAME(mul->rhs(), "y"); - } - - { - // Term - Polynomial sub where the result is a Polynomial. - ExprHandle body = (x * 2) - (x + 1); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Sub, simplified.node(), sub); - - IS_VAR_WITH_NAME(sub->lhs(), "x"); - IS_IMM_WITH_VAL(Int, sub->rhs(), 1); - } -} - -TEST(Simplify, SimplifyDiv) { - VarHandle x("x", kInt); - - { - ExprHandle body = ExprHandle(0) / x; - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_IMM_WITH_VAL(Int, simplified.node(), 0); - } - - { - ExprHandle body = x / 1; - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_VAR_WITH_NAME(simplified.node(), "x"); - } -} - -TEST(Simplify, SimplifyDivWithLoopContext0) { - // Stmt to simplify: - // for (int i = 0; i < 100; i++) { - // A[i] = i / 100; - //} - VarHandle i("i", kInt); - BufHandle a_buf("A", {100}, kInt); - auto for_stmt = For::make(i, 0, 100, Store::make(a_buf, {i}, (i / 100))); - - const StmtPtr simplified = IRSimplifier::simplify(for_stmt); - - std::ostringstream oss; - oss << *(simplified); - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NEXT: A[i] = 0; - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(Simplify, SimplifyDivWithLoopContext1) { - // Stmt to simplify: - // for (const auto i : c10::irange(6)) { - // A[i] = (i + 24) / 6; - //} - VarHandle i("i", kInt); - BufHandle a_buf("A", {6}, kInt); - auto for_stmt = For::make(i, 0, 6, Store::make(a_buf, {i}, (i + 24) / 6)); - - const StmtPtr simplified = IRSimplifier::simplify(for_stmt); - - std::ostringstream oss; - oss << *(simplified); - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NEXT: A[i] = 4; - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(Simplify, SimplifyDivWithLoopContext2) { - // Stmt to simplify: - // for (const auto i : c10::irange(5)) { - // A[i] = (i + 25) / 6; - //} - VarHandle i("i", kInt); - BufHandle a_buf("A", {5}, kInt); - auto for_stmt = For::make(i, 0, 5, Store::make(a_buf, {i}, (i + 25) / 6)); - - const StmtPtr simplified = IRSimplifier::simplify(for_stmt); - - std::ostringstream oss; - oss << *(simplified); - const std::string& verification_pattern = - R"IR( -# CHECK: for 
(int i -# CHECK-NEXT: A[i] = 4; - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(Simplify, SimplifyDivWithLoopContext3) { - // Stmt to simplify: - // for (const auto i : c10::irange(6)) { - // A[i] = (i + 24) / (-6); - //} - VarHandle i("i", kInt); - BufHandle a_buf("A", {6}, kInt); - auto for_stmt = For::make(i, 0, 6, Store::make(a_buf, {i}, (i + 24) / (-6))); - - const StmtPtr simplified = IRSimplifier::simplify(for_stmt); - - std::ostringstream oss; - oss << *(simplified); - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NOT: A[i] = -4; - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(Simplify, SimplifyDivWithLoopContext4) { - // Stmt to simplify: - // for (const auto i : c10::irange(5)) { - // A[i] = (i - 5) / 6; - //} - VarHandle i("i", kInt); - BufHandle a_buf("A", {5}, kInt); - auto for_stmt = For::make(i, 0, 5, Store::make(a_buf, {i}, (i + (-5)) / 6)); - - const StmtPtr simplified = IRSimplifier::simplify(for_stmt); - - std::ostringstream oss; - oss << *(simplified); - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NOT: A[i] = 0; - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(Simplify, SimplifyDivWithLoopContext5) { - // Stmt to simplify: - // for (const auto i : c10::irange(6)) { - // for (const auto j : c10::irange(10)) { - // A[i, j] = (i + 6*j) / 6; - // } - //} - VarHandle i("i", kInt); - VarHandle j("j", kInt); - BufHandle a_buf("A", {6, 10}, kInt); - auto for_j = For::make(j, 0, 10, Store::make(a_buf, {i, j}, (i + j * 6) / 6)); - auto for_i = For::make(i, 0, 6, for_j); - - const StmtPtr simplified = IRSimplifier::simplify(for_i); - - std::ostringstream oss; - oss << *(simplified); - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK: for (int j -# CHECK-NEXT: A[i, j] = j; - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(Simplify, SimplifyDivWithLoopContext6) { - // Stmt to simplify: - // for (const auto i : c10::irange(6)) { - // for (int j = -1; j < 9; j++) { - // A[i, j+1] = (i + 6*j) / 6; - // } - //} - VarHandle i("i", kInt); - VarHandle j("j", kInt); - BufHandle a_buf("A", {6, 10}, kInt); - auto for_j = - For::make(j, -1, 9, Store::make(a_buf, {i, j + 1}, (i + j * 6) / 6)); - auto for_i = For::make(i, 0, 6, for_j); - - const StmtPtr simplified = IRSimplifier::simplify(for_i); - - std::ostringstream oss; - oss << *(simplified); - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK: for (int j -# CHECK-NOT: A[i, j] = j; - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(Simplify, SimplifyDivWithLoopContext7) { - // Stmt to simplify: - // for (const auto i : c10::irange(6)) { - // for (const auto j : c10::irange(10)) { - // A[i, j] = (i + 6*j) / (-6); - // } - //} - VarHandle i("i", kInt); - VarHandle j("j", kInt); - BufHandle a_buf("A", {6, 10}, kInt); - auto for_j = - For::make(j, 0, 10, Store::make(a_buf, {i, j}, (i + j * 6) / (-6))); - auto for_i = For::make(i, 0, 6, for_j); - - const StmtPtr simplified = IRSimplifier::simplify(for_i); - - std::ostringstream oss; - oss << *(simplified); - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK: for (int j -# CHECK-NOT: A[i, j] = -j; - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(Simplify, SimplifyModWithLoopContext0) { 
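The Div cases above and the Mod cases below lean directly on C++ integer semantics: division truncates toward zero and (a / b) * b + a % b == a, so the remainder takes the sign of the dividend. Note that (i + 24) / (-6) really is -4 for i in [0, 6); the CHECK-NOT patterns only assert that the simplifier declines to fold divisions with negative divisors, not that the folded value would be wrong. A standalone check of the underlying arithmetic (plain C++, no tensorexpr dependencies):

#include <cassert>

int main() {
  for (int i = 0; i < 6; ++i) {
    assert((i + 24) / (-6) == -4); // truncation toward zero
    assert((i + 24) % (-6) == i);  // remainder follows the dividend's sign
    // The identity both the Div and Mod rewrites rely on:
    assert(((i + 24) / (-6)) * (-6) + (i + 24) % (-6) == i + 24);
  }
  return 0;
}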
- // Stmt to simplify: - // for (const auto i : c10::irange(100)) { - // A[i] = i % 100; - //} - VarHandle i("i", kInt); - BufHandle a_buf("A", {100}, kInt); - auto for_stmt = For::make(i, 0, 100, Store::make(a_buf, {i}, (i % 100))); - - const StmtPtr simplified = IRSimplifier::simplify(for_stmt); - - std::ostringstream oss; - oss << *(simplified); - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NEXT: A[i] = i; - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(Simplify, SimplifyModWithLoopContext1) { - // Stmt to simplify: - // for (const auto i : c10::irange(6)) { - // A[i] = (i + 24) % 6; - //} - VarHandle i("i", kInt); - BufHandle a_buf("A", {6}, kInt); - auto for_stmt = For::make(i, 0, 6, Store::make(a_buf, {i}, (i + 24) % 6)); - - const StmtPtr simplified = IRSimplifier::simplify(for_stmt); - - std::ostringstream oss; - oss << *(simplified); - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NEXT: A[i] = i; - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(Simplify, SimplifyModWithLoopContext2) { - // Stmt to simplify: - // for (const auto i : c10::irange(5)) { - // A[i] = (i + 25) % 6; - //} - VarHandle i("i", kInt); - BufHandle a_buf("A", {5}, kInt); - auto for_stmt = For::make(i, 0, 5, Store::make(a_buf, {i}, (i + 25) % 6)); - - const StmtPtr simplified = IRSimplifier::simplify(for_stmt); - - std::ostringstream oss; - oss << *(simplified); - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NEXT: A[i] = i + 1; - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(Simplify, SimplifyModWithLoopContext3) { - // Stmt to simplify: - // for (const auto i : c10::irange(6)) { - // A[i] = (i + 24) % (-6); - //} - VarHandle i("i", kInt); - BufHandle a_buf("A", {6}, kInt); - auto for_stmt = For::make(i, 0, 6, Store::make(a_buf, {i}, (i + 24) % (-6))); - - const StmtPtr simplified = IRSimplifier::simplify(for_stmt); - - std::ostringstream oss; - oss << *(simplified); - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NOT: A[i] = i; - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(Simplify, SimplifyModWithLoopContext4) { - // Stmt to simplify: - // for (const auto i : c10::irange(5)) { - // A[i] = (i - 5) % 6; - //} - VarHandle i("i", kInt); - BufHandle a_buf("A", {5}, kInt); - auto for_stmt = For::make(i, 0, 5, Store::make(a_buf, {i}, (i + (-5)) % 6)); - - const StmtPtr simplified = IRSimplifier::simplify(for_stmt); - - std::ostringstream oss; - oss << *(simplified); - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK-NOT: A[i] = i - 5; - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(Simplify, SimplifyModWithLoopContext5) { - // Stmt to simplify: - // for (const auto i : c10::irange(6)) { - // for (const auto j : c10::irange(10)) { - // A[i, j] = (i + 6*j) % 6; - // } - //} - VarHandle i("i", kInt); - VarHandle j("j", kInt); - BufHandle a_buf("A", {6, 10}, kInt); - auto for_j = For::make(j, 0, 10, Store::make(a_buf, {i, j}, (i + j * 6) % 6)); - auto for_i = For::make(i, 0, 6, for_j); - - const StmtPtr simplified = IRSimplifier::simplify(for_i); - - std::ostringstream oss; - oss << *(simplified); - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK: for (int j -# CHECK-NEXT: A[i, j] = i; - )IR"; - 
torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(Simplify, SimplifyModWithLoopContext6) { - // Stmt to simplify: - // for (const auto i : c10::irange(6)) { - // for (int j = -1; j < 9; j++) { - // A[i, j+1] = (i + 6*j) % 6; - // } - //} - VarHandle i("i", kInt); - VarHandle j("j", kInt); - BufHandle a_buf("A", {6, 10}, kInt); - auto for_j = - For::make(j, -1, 9, Store::make(a_buf, {i, j + 1}, (i + j * 6) % 6)); - auto for_i = For::make(i, 0, 6, for_j); - - const StmtPtr simplified = IRSimplifier::simplify(for_i); - - std::ostringstream oss; - oss << *(simplified); - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK: for (int j -# CHECK-NOT: A[i, j] = i; - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(Simplify, SimplifyModWithLoopContext7) { - // Stmt to simplify: - // for (const auto i : c10::irange(6)) { - // for (const auto j : c10::irange(10)) { - // A[i, j] = (i + 6*j) % (-6); - // } - //} - VarHandle i("i", kInt); - VarHandle j("j", kInt); - BufHandle a_buf("A", {6, 10}, kInt); - auto for_j = - For::make(j, 0, 10, Store::make(a_buf, {i, j}, (i + j * 6) % (-6))); - auto for_i = For::make(i, 0, 6, for_j); - - const StmtPtr simplified = IRSimplifier::simplify(for_i); - - std::ostringstream oss; - oss << *(simplified); - const std::string& verification_pattern = - R"IR( -# CHECK: for (int i -# CHECK: for (int j -# CHECK-NOT: A[i, j] = i; - )IR"; - torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); -} - -TEST(Simplify, SimplifyMod) { - VarHandle x("x", kInt); - VarHandle y("y", kInt); - VarHandle z("z", kInt); - - { - // Constant folding works. - ExprHandle body = ExprHandle(10) % 8; - ExprHandle simplified = IRSimplifier::simplify(body); - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - IS_IMM_WITH_VAL(Int, simplified.node(), 2); - } - - { - // x % x => 0 - ExprHandle body = x % x; - ExprHandle simplified = IRSimplifier::simplify(body); - IS_IMM_WITH_VAL(Int, simplified.node(), 0); - } - - { - // 0 % x => 0 - ExprHandle body = ExprHandle(0) % x; - ExprHandle simplified = IRSimplifier::simplify(body); - IS_IMM_WITH_VAL(Int, simplified.node(), 0); - } - - { - // x % 1 => 0 - ExprHandle body = x % 1; - ExprHandle simplified = IRSimplifier::simplify(body); - IS_IMM_WITH_VAL(Int, simplified.node(), 0); - } - - { - // Doesn't change unknown mods. - // x % y => x % y - ExprHandle body = x % y; - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Mod, simplified.node(), mod); - IS_VAR_WITH_NAME(mod->lhs(), "x"); - IS_VAR_WITH_NAME(mod->rhs(), "y"); - } - - { - // don't touch if RHS is unknown. - // 4 % x => 4 % x - ExprHandle body = ExprHandle(4) % x; - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Mod, simplified.node(), mod); - IS_IMM_WITH_VAL(Int, mod->lhs(), 4); - IS_VAR_WITH_NAME(mod->rhs(), "x"); - } - - { - // don't touch if LHS is unknown. - // x % 4 => x % 4 - ExprHandle body = x % 4; - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Mod, simplified.node(), mod); - IS_VAR_WITH_NAME(mod->lhs(), "x"); - IS_IMM_WITH_VAL(Int, mod->rhs(), 4); - } - - { - // if LHS is a multiple of RHS, mod is zero. - // 2 * x % x => 0 - ExprHandle body = (x * 2) % x; - ExprHandle simplified = IRSimplifier::simplify(body); - IS_IMM_WITH_VAL(Int, simplified.node(), 0); - } - - { - // true even if the multiple is not constant. 
- // x * y % x => 0 - ExprHandle body = (x * y) % x; - ExprHandle simplified = IRSimplifier::simplify(body); - IS_IMM_WITH_VAL(Int, simplified.node(), 0); - } - - { - // true with multiple unknown values in LHS. - // x * y * z % x => 0 - ExprHandle body = (x * y * z) % x; - ExprHandle simplified = IRSimplifier::simplify(body); - IS_IMM_WITH_VAL(Int, simplified.node(), 0); - } - - { - // true if the denom is compound. - // x * y * z % y * z => 0 - ExprHandle body = (x * y * z) % (y * z); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_IMM_WITH_VAL(Int, simplified.node(), 0); - } - - { - // Sanity check true with scalars that are multiples. - // 12 * x % 4 => 0 - ExprHandle body = (x * 12) % 4; - ExprHandle simplified = IRSimplifier::simplify(body); - IS_IMM_WITH_VAL(Int, simplified.node(), 0); - } - - { - // Sanity check not true if the smaller scalar is on LHS. - // 4 * x % 12 => 4 * x % 12 - ExprHandle body = (x * 4) % 12; - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Mod, simplified.node(), mod); - IS_NODE_WITH_NAME(Mul, mod->lhs(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), 4); - IS_VAR_WITH_NAME(mul->rhs(), "x"); - IS_IMM_WITH_VAL(Int, mod->rhs(), 12); - } - - { - // Both scalar and symbolic in multiple. - // (6 * x * y) % (3 * x * y) => 0 - ExprHandle body = (ExprHandle(6) * x * y) % (x * y * 3); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_IMM_WITH_VAL(Int, simplified.node(), 0); - } -} - -// Test that mixing ops together simplifies as expected. -TEST(Simplify, SimplifyMultiOp) { - VarHandle x("x", kInt); - VarHandle y("y", kInt); - - { - // (x * y) + (x - y) => (x + x * y) - y - ExprHandle body = (x * y) + (x - y); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Sub, simplified.node(), sub); - IS_NODE_WITH_NAME(Add, sub->lhs(), add); - IS_VAR_WITH_NAME(add->lhs(), "x"); - IS_NODE_WITH_NAME(Mul, add->rhs(), mul); - IS_VAR_WITH_NAME(mul->lhs(), "x"); - IS_VAR_WITH_NAME(mul->rhs(), "y"); - IS_VAR_WITH_NAME(sub->rhs(), "y"); - } - - { - // (x + y) - x * y => (x + y) - x * y - ExprHandle body = (x + y) - x * y; - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Sub, simplified.node(), sub); - IS_NODE_WITH_NAME(Add, sub->lhs(), add); - IS_NODE_WITH_NAME(Mul, sub->rhs(), mul); - IS_VAR_WITH_NAME(add->lhs(), "x"); - IS_VAR_WITH_NAME(add->rhs(), "y"); - IS_VAR_WITH_NAME(mul->lhs(), "x"); - IS_VAR_WITH_NAME(mul->rhs(), "y"); - } - - { - // (x - y) - (x + y) => -2 * y - ExprHandle body = (x - y) - (x + y); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), -2); - IS_VAR_WITH_NAME(mul->rhs(), "y"); - } - - { - // (x - 0) + (x * 1) - (x + 0) => x - ExprHandle body = (x - 0) + (x * 1) - (x + 0); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_VAR_WITH_NAME(simplified.node(), "x"); - } - - { - // (x - 0.f) + (x * 1.f) - (x + 0.f) => float(x) + float(x) - float(x) - // Even in Float simple terms cancel out, but the variable ones cannot. 
- ExprHandle body = - (x - ExprHandle(0.f)) + (x * ExprHandle(1.f)) - (x + ExprHandle(0.f)); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Sub, simplified.node(), sub); - IS_NODE_WITH_NAME(Add, sub->lhs(), add); - IS_NODE_WITH_NAME(Cast, add->lhs(), cast1); - IS_VAR_WITH_NAME(cast1->src_value(), "x"); - IS_NODE_WITH_NAME(Cast, add->rhs(), cast2); - IS_VAR_WITH_NAME(cast2->src_value(), "x"); - IS_NODE_WITH_NAME(Cast, sub->rhs(), cast3); - IS_VAR_WITH_NAME(cast3->src_value(), "x"); - } -} - -// Test that chaining many ops together works as expected. -TEST(Simplify, SimplifyManyOps) { - VarHandle x("x", kInt); - VarHandle y("y", kInt); - - { - // x + y + x + x + y + y + x + y + x = 4 * y + 5 * x - ExprHandle body = x + y + x + x + y + y + x + y + x; - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Add, simplified.node(), add); - - IS_NODE_WITH_NAME(Mul, add->lhs(), lhs); - IS_IMM_WITH_VAL(Int, lhs->lhs(), 4); - IS_VAR_WITH_NAME(lhs->rhs(), "y"); - - IS_NODE_WITH_NAME(Mul, add->rhs(), rhs); - IS_IMM_WITH_VAL(Int, rhs->lhs(), 5); - IS_VAR_WITH_NAME(rhs->rhs(), "x"); - } - - { - // x - y + x + x - y - y + x - y + x = 5 * x - 4 * y - ExprHandle body = x - y + x + x - y - y + x - y + x; - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Sub, simplified.node(), add); - - IS_NODE_WITH_NAME(Mul, add->lhs(), lhs); - IS_IMM_WITH_VAL(Int, lhs->lhs(), 5); - IS_VAR_WITH_NAME(lhs->rhs(), "x"); - - IS_NODE_WITH_NAME(Mul, add->rhs(), rhs); - IS_IMM_WITH_VAL(Int, rhs->lhs(), 4); - IS_VAR_WITH_NAME(rhs->rhs(), "y"); - } - - { - // x + y + x - x - y - y + x + y + x = 3 * x - ExprHandle body = x + y + x - x - y - y + x + y + x; - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), 3); - IS_VAR_WITH_NAME(mul->rhs(), "x"); - } -} - -TEST(Simplify, SimplifyFactorization) { - VarHandle x("x", kInt); - VarHandle y("y", kInt); - - { - // (2 * x) + (2 * y) => 2 * (x + y) - ExprHandle body = (ExprHandle(2) * x + ExprHandle(2) * y); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), 2); - - IS_NODE_WITH_NAME(Add, mul->rhs(), add); - IS_VAR_WITH_NAME(add->lhs(), "x"); - IS_VAR_WITH_NAME(add->rhs(), "y"); - } - - { - // Factorization when scalars have common divider. - // (2 * x) + (4 * y) => 2 * (2 * y + x) - ExprHandle body = (ExprHandle(2) * x + ExprHandle(4) * y); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), 2); - - IS_NODE_WITH_NAME(Add, mul->rhs(), add); - IS_VAR_WITH_NAME(add->lhs(), "x"); - IS_NODE_WITH_NAME(Mul, add->rhs(), mul2); - IS_IMM_WITH_VAL(Int, mul2->lhs(), 2); - IS_VAR_WITH_NAME(mul2->rhs(), "y"); - } - - { - // Factorization attempt without a common divider. - // (2 * x) + (5 * y) => (5 * y) + (2 * x) - ExprHandle body = (ExprHandle(2) * x + ExprHandle(5) * y); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Add, simplified.node(), add); - - IS_NODE_WITH_NAME(Mul, add->lhs(), lhs); - IS_IMM_WITH_VAL(Int, lhs->lhs(), 2); - IS_VAR_WITH_NAME(lhs->rhs(), "x"); - - IS_NODE_WITH_NAME(Mul, add->rhs(), rhs); - IS_IMM_WITH_VAL(Int, rhs->lhs(), 5); - IS_VAR_WITH_NAME(rhs->rhs(), "y"); - } - - { - // Factorization after merging. 
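Factorization amounts to pulling the GCD out of the scalar coefficients, as the cases around this point illustrate. A minimal sketch of that decision using only the standard library (std::gcd is C++17; the loop bounds are arbitrary):

#include <cassert>
#include <numeric>

int main() {
  // (2 * x) + (4 * y): gcd(2, 4) == 2, so factor: 2 * (x + 2 * y).
  assert(std::gcd(2, 4) == 2);
  // (2 * x) + (5 * y): gcd(2, 5) == 1, so there is nothing to pull out.
  assert(std::gcd(2, 5) == 1);
  for (int x = -10; x <= 10; ++x) {
    for (int y = -10; y <= 10; ++y) {
      assert(2 * x + 4 * y == 2 * (x + 2 * y));
      assert(-2 * x + -4 * y == 2 * (-1 * x - 2 * y));
    }
  }
  return 0;
}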
- // (2 * x) + (4 * y) + (8 * x + 6 * y) => 10 * (x + y)
- ExprHandle body = (ExprHandle(2) * x + ExprHandle(4) * y) +
- (ExprHandle(8) * x + ExprHandle(6) * y);
- ExprHandle simplified = IRSimplifier::simplify(body);
-
- IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
- IS_IMM_WITH_VAL(Int, mul->lhs(), 10);
-
- IS_NODE_WITH_NAME(Add, mul->rhs(), add);
- IS_VAR_WITH_NAME(add->lhs(), "x");
- IS_VAR_WITH_NAME(add->rhs(), "y");
- }
-
- {
- // Factorization with common divider but different signs.
- // (2 * x) + (-4 * y) => 2 * (x - 2 * y)
- ExprHandle body = (ExprHandle(2) * x + ExprHandle(-4) * y);
- ExprHandle simplified = IRSimplifier::simplify(body);
-
- IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
- IS_IMM_WITH_VAL(Int, mul->lhs(), 2);
-
- IS_NODE_WITH_NAME(Sub, mul->rhs(), sub);
- IS_VAR_WITH_NAME(sub->lhs(), "x");
- IS_NODE_WITH_NAME(Mul, sub->rhs(), mul2);
- IS_IMM_WITH_VAL(Int, mul2->lhs(), 2);
- IS_VAR_WITH_NAME(mul2->rhs(), "y");
- }
-
- {
- // Factorization with all negative numbers.
- // (-2 * x) + (-4 * y) => 2 * (-1 * x - 2 * y)
- ExprHandle body = ExprHandle(-2) * x + ExprHandle(-4) * y;
- ExprHandle simplified = IRSimplifier::simplify(body);
-
- IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
- IS_IMM_WITH_VAL(Int, mul->lhs(), 2);
-
- IS_NODE_WITH_NAME(Sub, mul->rhs(), sub);
- IS_NODE_WITH_NAME(Mul, sub->lhs(), mul2);
- IS_IMM_WITH_VAL(Int, mul2->lhs(), -1);
- IS_VAR_WITH_NAME(mul2->rhs(), "x");
- IS_NODE_WITH_NAME(Mul, sub->rhs(), mul3);
- IS_IMM_WITH_VAL(Int, mul3->lhs(), 2);
- IS_VAR_WITH_NAME(mul3->rhs(), "y");
- }
-
- {
- // The following test ensures that there is no infinite recursion during
- // factorization when negative numbers are involved.
- VarHandle a("a", kInt);
- VarHandle b("b", kInt);
- VarHandle c("c", kInt);
- VarHandle d("d", kInt);
- VarHandle e("e", kInt);
- VarHandle f("f", kInt);
- VarHandle g("g", kInt);
- VarHandle h("h", kInt);
-
- ExprHandle body = a * 1024 + 0 + b * (-1) + c * (-1) + d * 1 + e * 1 +
- f * 32 + g * (-1024) + h * (-32);
- ExprHandle simplified = IRSimplifier::simplify(body);
- checkExprIR(
- simplified,
- "((((((d + e) + 1024 * a) + 32 * f) - b) - c) - 1024 * g) - 32 * h");
- }
-}
-
-// (4 * x + y + z * 2) + (4 * x + y + z * 4) => 2 * (y + 3 * z + 4 * x)
-TEST(Simplify, SimplifyFactorizeUneven) {
- VarHandle x("x", kInt);
- VarHandle y("y", kInt);
- VarHandle z("z", kInt);
- ExprHandle body =
- (ExprHandle(4) * x + y + z * 2) + (ExprHandle(4) * x + y + z * 4);
- ExprHandle simplified = IRSimplifier::simplify(body);
-
- IS_NODE_WITH_NAME(Mul, simplified.node(), root);
- IS_IMM_WITH_VAL(Int, root->lhs(), 2);
- IS_NODE_WITH_NAME(Add, root->rhs(), add1);
- IS_NODE_WITH_NAME(Add, add1->lhs(), add2);
-
- IS_VAR_WITH_NAME(add2->lhs(), "y");
- IS_NODE_WITH_NAME(Mul, add2->rhs(), zmul);
- IS_NODE_WITH_NAME(Mul, add1->rhs(), xmul);
-
- IS_IMM_WITH_VAL(Int, xmul->lhs(), 4);
- IS_VAR_WITH_NAME(xmul->rhs(), "x");
-
- IS_IMM_WITH_VAL(Int, zmul->lhs(), 3);
- IS_VAR_WITH_NAME(zmul->rhs(), "z");
-}
-
-// (x * y) + (2 * x) * (x + y) => 2 * (x * x) + 3 * (x * y)
-// This is kind of a placeholder test for variable factorization.
-TEST(Simplify, SimplifyDeeperTerms) { - VarHandle x("x", kInt); - VarHandle y("y", kInt); - ExprHandle body = (x * y) + (ExprHandle(2) * x) * (x + y); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Add, simplified.node(), add); - - IS_NODE_WITH_NAME(Mul, add->lhs(), lhs); - IS_IMM_WITH_VAL(Int, lhs->lhs(), 2); - IS_NODE_WITH_NAME(Mul, lhs->rhs(), xxTerm); - IS_VAR_WITH_NAME(xxTerm->lhs(), "x"); - IS_VAR_WITH_NAME(xxTerm->rhs(), "x"); - - IS_NODE_WITH_NAME(Mul, add->rhs(), rhs); - IS_IMM_WITH_VAL(Int, rhs->lhs(), 3); - IS_NODE_WITH_NAME(Mul, rhs->rhs(), xyTerm); - IS_VAR_WITH_NAME(xyTerm->lhs(), "x"); - IS_VAR_WITH_NAME(xyTerm->rhs(), "y"); -} - -// Tests the difference between two less trivial expressions. -// (m * (1 * n_1) + (n + 1)) - (m * (1 * n_1) + n) => 1 -TEST(Simplify, SimplifyDeeperDifference) { - VarHandle n("n", kInt); - VarHandle n_1("n_1", kInt); - VarHandle m("m", kInt); - ExprHandle body = - (m * (ExprHandle(1) * n_1) + (n + 1)) - (m * (ExprHandle(1) * n_1) + n); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_IMM_WITH_VAL(Int, simplified.node(), 1); -} - -// Test constant folding into the difference between expressions. -// 2 + char((m * (1 * n_1) + (n + 1)) - (m * (1 * n_1) + n)) => 3 -TEST(Simplify, SimplifyFoldComplexDifference) { - VarHandle n("n", kInt); - VarHandle n_1("n_1", kInt); - VarHandle m("m", kInt); - ExprHandle body = - (IntImm::make(2) + - (Cast::make( - kChar, - (m * (ExprHandle(1) * n_1) + (n + 1)) - - (m * (ExprHandle(1) * n_1) + n)))); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_IMM_WITH_VAL(Int, simplified.node(), 3); -} - -TEST(Simplify, SimplifyIfComponents) { - VarHandle x("x", kInt); - VarHandle y("y", kInt); - ExprHandle body = IfThenElse::make( - ((ExprHandle(5) - ExprHandle(4)) * x) > y, - ExprHandle(2) * x - x, - ExprHandle(2) * y - y); - - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(IfThenElse, simplified.node(), ifexpr); - - IS_NODE_WITH_NAME(CompareSelect, ifexpr->condition(), cmp); - ASSERT_EQ(cmp->compare_select_op(), kGT); - IS_VAR_WITH_NAME(cmp->lhs(), "x"); - IS_VAR_WITH_NAME(cmp->rhs(), "y"); - - IS_VAR_WITH_NAME(ifexpr->true_value(), "x"); - IS_VAR_WITH_NAME(ifexpr->false_value(), "y"); -} - -TEST(Simplify, SimplifyOpaqueTerms) { - VarHandle x("x", kInt); - VarHandle y("y", kInt); - - { - // 2 * x/y * y - x/y * y => x/y * y - ExprHandle body = ((ExprHandle(2)) * (x / y) * y) - ((x / y) * y); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_NODE_WITH_NAME(Div, mul->lhs(), div); - IS_VAR_WITH_NAME(div->lhs(), "x"); - IS_VAR_WITH_NAME(div->rhs(), "y"); - IS_VAR_WITH_NAME(mul->rhs(), "y"); - } - - { - // x%y - (x%y - 1) => 1 - ExprHandle body = (x % y) - ((x % y) - 1); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_IMM_WITH_VAL(Int, simplified.node(), 1); - } -} - -TEST(Simplify, SimplifySymbolicMinMax) { - { - // Minimum with constant difference between terms. - VarHandle x("x", kInt); - ExprHandle body = Min::make(x + 3, x + 7, true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Add, simplified.node(), add); - IS_VAR_WITH_NAME(add->lhs(), "x"); - IS_IMM_WITH_VAL(Int, add->rhs(), 3); - } - - { - // Maximum with constant difference between terms. 
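These constant-difference cases work because adding the same variable to both arguments preserves ordering, while constant multiples flip with the sign of the variable. A standalone illustration with std::min/std::max (plain C++; it ignores the signed-overflow corners a real simplifier must also respect):

#include <algorithm>
#include <cassert>

int main() {
  for (int x = -100; x <= 100; ++x) {
    assert(std::min(x + 3, x + 7) == x + 3); // constant offsets: always foldable
    assert(std::max(x + 3, x + 7) == x + 7);
  }
  // Constant multiples are not foldable for signed x:
  assert(std::max(-1 * 3, -1 * 7) == -3); // x == -1 picks the smaller multiplier
  assert(std::max(1 * 3, 1 * 7) == 7);    // x == 1 picks the larger one
  return 0;
}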
- VarHandle x("x", kInt); - ExprHandle body = Max::make(x + 3, x + 7, true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Add, simplified.node(), add); - IS_VAR_WITH_NAME(add->lhs(), "x"); - IS_IMM_WITH_VAL(Int, add->rhs(), 7); - } - - { - // Can't simplify multiples because of signedness of variable component. - // TODO: maybe we could for unsigned types? - VarHandle x("x", kInt); - ExprHandle body = Max::make(x * 3, x * 7, true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE(Max, simplified.node()); - } -} - -TEST(Simplify, SimplifyNestedMax) { - VarHandle x("x", kInt); - VarHandle y("y", kInt); - VarHandle z("z", kInt); - - { - // Max(x + y, x + y) => x + y - ExprHandle body = Max::make(x + y, x + y, true); - ExprHandle simplified = IRSimplifier::simplify(body); - - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - IS_BINOP_W_VARS(Add, simplified.node(), add, "x", "y"); - } - - { - // Max(x + y, Max(x + y, z)) => Max(x + y, z) - ExprHandle body = Max::make(x + y, Max::make(x + y, z, true), true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Max, simplified.node(), max); - IS_BINOP_W_VARS(Add, max->lhs(), add, "x", "y"); - IS_VAR_WITH_NAME(max->rhs(), "z"); - } - - { - // Max(x + y, Max(z, x + y)) => Max(x + y, z) - ExprHandle body = Max::make(x + y, Max::make(z, x + y, true), true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Max, simplified.node(), max); - IS_BINOP_W_VARS(Add, max->lhs(), add, "x", "y"); - IS_VAR_WITH_NAME(max->rhs(), "z"); - } - - { - // Max(Max(x + y, z), x + y) => Max(x + y, z) - ExprHandle body = Max::make(Max::make(x + y, z, true), x + y, true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Max, simplified.node(), max); - IS_BINOP_W_VARS(Add, max->lhs(), add, "x", "y"); - IS_VAR_WITH_NAME(max->rhs(), "z"); - } - - { - // Max(Max(z, x + y), x + y) => Max(x + y, z) - ExprHandle body = Max::make(Max::make(z, x + y, true), x + y, true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Max, simplified.node(), max); - IS_BINOP_W_VARS(Add, max->lhs(), add, "x", "y"); - IS_VAR_WITH_NAME(max->rhs(), "z"); - } - - { - // Max(Max(x, y), x) => Max(Max(x, y), x) - // Nested Max ops with different propagate_nans should not be simplified. 
- ExprHandle body = Max::make(Max::make(x, y, true), x, false); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Max, simplified.node(), max); - IS_BINOP_W_VARS(Max, max->lhs(), max1, "x", "y"); - ASSERT_TRUE(max1->propagate_nans()); - IS_VAR_WITH_NAME(max->rhs(), "x"); - ASSERT_FALSE(max->propagate_nans()); - } - - { - // Max(Min(x, y), Min(x, z)) => Min(Max(y, z), x) - ExprHandle body = - Max::make(Min::make(x, y, true), Min::make(x, z, true), true); - ExprHandle simplified = IRSimplifier::simplify(body); - checkExprIR(simplified, "Min(Max(y, z, 1), x, 1)"); - } - - { - // Max(Min(x, y), Min(z, x)) => Min(Max(y, z), x) - ExprHandle body = - Max::make(Min::make(x, y, true), Min::make(z, x, true), true); - ExprHandle simplified = IRSimplifier::simplify(body); - checkExprIR(simplified, "Min(Max(y, z, 1), x, 1)"); - } - - { - // Max(Min(y, x), Min(x, z)) => Min(Max(y, z), x) - ExprHandle body = - Max::make(Min::make(y, x, true), Min::make(x, z, true), true); - ExprHandle simplified = IRSimplifier::simplify(body); - checkExprIR(simplified, "Min(Max(y, z, 1), x, 1)"); - } - - { - // Max(Min(y, x), Min(z, x)) => Min(Max(y, z), x) - ExprHandle body = - Max::make(Min::make(y, x, true), Min::make(z, x, true), true); - ExprHandle simplified = IRSimplifier::simplify(body); - checkExprIR(simplified, "Min(Max(y, z, 1), x, 1)"); - } - - { - // Max(Min(y, x), Min(z, x)) => Max(Min(x, y), Min(x, z)) - // When all the ops in the pattern do not have the same propagate_nans, - // it should not be simplified. - ExprHandle body = - Max::make(Min::make(y, x, true), Min::make(z, x, false), true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Max, simplified.node(), max); - IS_BINOP_W_VARS(Min, max->lhs(), min1, "x", "y"); - ASSERT_TRUE(min1->propagate_nans()); - IS_BINOP_W_VARS(Min, max->rhs(), min2, "x", "z"); - ASSERT_FALSE(min2->propagate_nans()); - ASSERT_TRUE(max->propagate_nans()); - } - - { - // Max(5, Max(x, 8)) => Max(x, 8) - ExprHandle body = Max::make(5, Max::make(x, 8, true), true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_BINOP_W_CONST(Max, simplified.node(), max, "x", 8); - ASSERT_TRUE(max->propagate_nans()); - } - - { - // Max(8, Max(x, 5)) => Max(x, 8) - ExprHandle body = Max::make(8, Max::make(x, 5, true), true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_BINOP_W_CONST(Max, simplified.node(), max, "x", 8); - ASSERT_TRUE(max->propagate_nans()); - } - - { - // Max(Max(x, 8), 5) => Max(x, 8) - ExprHandle body = Max::make(Max::make(x, 8, true), 5, true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_BINOP_W_CONST(Max, simplified.node(), max, "x", 8); - ASSERT_TRUE(max->propagate_nans()); - } - - { - // Max(Max(x, 5), 8) => Max(x, 8) - ExprHandle body = Max::make(Max::make(x, 5, true), 8, true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_BINOP_W_CONST(Max, simplified.node(), max, "x", 8); - ASSERT_TRUE(max->propagate_nans()); - } - - { - // Max(5, Max(x, Max(y, Max(z, 8)))) => Max(Max(Max(x, 8), y), z) - ExprHandle body = Max::make( - 5, Max::make(x, Max::make(y, Max::make(z, 8, true), true), true), true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Max, simplified.node(), max1); - IS_NODE_WITH_NAME(Max, max1->lhs(), max2); - IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x", 8); - ASSERT_TRUE(max3->propagate_nans()); - IS_VAR_WITH_NAME(max2->rhs(), "y"); - IS_VAR_WITH_NAME(max1->rhs(), "z"); - } - - { - // Max(8, Max(Max(y, Max(z, 
5)), x)) => Max(Max(Max(x, 8), y), z) - ExprHandle body = Max::make( - 8, Max::make(Max::make(y, Max::make(z, 5, true), true), x, true), true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Max, simplified.node(), max1); - IS_NODE_WITH_NAME(Max, max1->lhs(), max2); - IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x", 8); - ASSERT_TRUE(max3->propagate_nans()); - IS_VAR_WITH_NAME(max2->rhs(), "y"); - IS_VAR_WITH_NAME(max1->rhs(), "z"); - } - - { - // Max(5, Max(Max(Max(z, 8), y), x)) => Max(Max(Max(x, 8), y), z) - ExprHandle body = Max::make( - 5, Max::make(Max::make(Max::make(z, 8, true), y, true), x, true), true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Max, simplified.node(), max1); - IS_NODE_WITH_NAME(Max, max1->lhs(), max2); - IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x", 8); - ASSERT_TRUE(max3->propagate_nans()); - IS_VAR_WITH_NAME(max2->rhs(), "y"); - IS_VAR_WITH_NAME(max1->rhs(), "z"); - } - - { - // Max(Max(x, Max(y, Max(5, z))), 8) => Max(Max(Max(x, 8), y), z) - ExprHandle body = Max::make( - Max::make(x, Max::make(y, Max::make(5, z, true), true), true), 8, true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Max, simplified.node(), max1); - IS_NODE_WITH_NAME(Max, max1->lhs(), max2); - IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x", 8); - ASSERT_TRUE(max3->propagate_nans()); - IS_VAR_WITH_NAME(max2->rhs(), "y"); - IS_VAR_WITH_NAME(max1->rhs(), "z"); - } - - { - // Max(Max(Max(y, Max(8, z)), x), 5) => Max(Max(Max(x, 8), y), z) - ExprHandle body = Max::make( - Max::make(Max::make(y, Max::make(z, 8, true), true), x, true), 5, true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Max, simplified.node(), max1); - IS_NODE_WITH_NAME(Max, max1->lhs(), max2); - IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x", 8); - ASSERT_TRUE(max3->propagate_nans()); - IS_VAR_WITH_NAME(max2->rhs(), "y"); - IS_VAR_WITH_NAME(max1->rhs(), "z"); - } - - { - // Max(Max(Max(Max(5, z), y), x), 8) => Max(Max(Max(x, 8), y), z) - ExprHandle body = Max::make( - Max::make(Max::make(Max::make(z, 5, true), y, true), x, true), 8, true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Max, simplified.node(), max1); - IS_NODE_WITH_NAME(Max, max1->lhs(), max2); - IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x", 8); - ASSERT_TRUE(max3->propagate_nans()); - IS_VAR_WITH_NAME(max2->rhs(), "y"); - IS_VAR_WITH_NAME(max1->rhs(), "z"); - } - - { - // Max(Max(Max(Max(z, 5), y), x), 8) => Max(Max(x, Max(Max(z, 5), y)), 8) - // Do not simplify when all the Max ops do not have the same - // propagate_nans. 
- ExprHandle body = Max::make( - Max::make(Max::make(Max::make(z, 5, true), y, false), x, true), - 8, - false); - ExprHandle simplified = IRSimplifier::simplify(body); - checkExprIR(simplified, "Max(Max(Max(Max(z, 5, 1), y, 0), x, 1), 8, 0)"); - } - - { - // Max(8, Max(Max(x, 5), Max(y, z))) => Max(Max(Max(x, 8), y), z) - ExprHandle body = Max::make( - 8, Max::make(Max::make(x, 5, true), Max::make(y, z, true), true), true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Max, simplified.node(), max1); - IS_NODE_WITH_NAME(Max, max1->lhs(), max2); - IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x", 8); - ASSERT_TRUE(max3->propagate_nans()); - IS_VAR_WITH_NAME(max2->rhs(), "y"); - IS_VAR_WITH_NAME(max1->rhs(), "z"); - } - - { - // Max(Max(Max(x, 5), Max(y, z)), 8) => Max(Max(Max(x, 8), y), z) - ExprHandle body = Max::make( - Max::make(Max::make(x, 5, true), Max::make(y, z, true), true), 8, true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Max, simplified.node(), max1); - IS_NODE_WITH_NAME(Max, max1->lhs(), max2); - IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x", 8); - ASSERT_TRUE(max3->propagate_nans()); - IS_VAR_WITH_NAME(max2->rhs(), "y"); - IS_VAR_WITH_NAME(max1->rhs(), "z"); - } -} - -TEST(Simplify, SimplifyNestedMin) { - VarHandle x("x", kInt); - VarHandle y("y", kInt); - VarHandle z("z", kInt); - - { - // Min(x + y, x + y) => x + y - ExprHandle body = Min::make(x + y, x + y, true); - ExprHandle simplified = IRSimplifier::simplify(body); - - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - IS_BINOP_W_VARS(Add, simplified.node(), add, "x", "y"); - } - - { - // Min(x + y, Min(x + y, z)) => Min(x + y, z) - ExprHandle body = Min::make(x + y, Min::make(x + y, z, true), true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Min, simplified.node(), min); - IS_BINOP_W_VARS(Add, min->lhs(), add, "x", "y"); - IS_VAR_WITH_NAME(min->rhs(), "z"); - } - - { - // Min(x + y, Min(z, x + y)) => Min(x + y, z) - ExprHandle body = Min::make(x + y, Min::make(z, x + y, true), true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Min, simplified.node(), min); - IS_BINOP_W_VARS(Add, min->lhs(), add, "x", "y"); - IS_VAR_WITH_NAME(min->rhs(), "z"); - } - - { - // Min(Min(x + y, z), x + y) => Min(x + y, z) - ExprHandle body = Min::make(Min::make(x + y, z, true), x + y, true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Min, simplified.node(), min); - IS_BINOP_W_VARS(Add, min->lhs(), add, "x", "y"); - IS_VAR_WITH_NAME(min->rhs(), "z"); - } - - { - // Min(Min(z, x + y), x + y) => Min(x + y, z) - ExprHandle body = Min::make(Min::make(z, x + y, true), x + y, true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Min, simplified.node(), min); - IS_BINOP_W_VARS(Add, min->lhs(), add, "x", "y"); - IS_VAR_WITH_NAME(min->rhs(), "z"); - } - - { - // Min(Min(x, y), x) => Min(Min(x, y), x) - // Nested Min ops with different propagate_nans should not be simplified. 
- ExprHandle body = Min::make(Min::make(x, y, true), x, false);
- ExprHandle simplified = IRSimplifier::simplify(body);
-
- IS_NODE_WITH_NAME(Min, simplified.node(), min1);
- IS_BINOP_W_VARS(Min, min1->lhs(), min2, "x", "y");
- ASSERT_TRUE(min2->propagate_nans());
- IS_VAR_WITH_NAME(min1->rhs(), "x");
- ASSERT_FALSE(min1->propagate_nans());
- }
-
- {
- // Min(Max(x, y), Max(x, z)) => Max(Min(y, z), x)
- ExprHandle body =
- Min::make(Max::make(x, y, true), Max::make(x, z, true), true);
- ExprHandle simplified = IRSimplifier::simplify(body);
- checkExprIR(simplified, "Max(Min(y, z, 1), x, 1)");
- }
-
- {
- // Min(Max(x, y), Max(z, x)) => Max(Min(y, z), x)
- ExprHandle body =
- Min::make(Max::make(x, y, true), Max::make(z, x, true), true);
- ExprHandle simplified = IRSimplifier::simplify(body);
- checkExprIR(simplified, "Max(Min(y, z, 1), x, 1)");
- }
-
- {
- // Min(Max(y, x), Max(x, z)) => Max(Min(y, z), x)
- ExprHandle body =
- Min::make(Max::make(y, x, true), Max::make(x, z, true), true);
- ExprHandle simplified = IRSimplifier::simplify(body);
- checkExprIR(simplified, "Max(Min(y, z, 1), x, 1)");
- }
-
- {
- // Min(Max(y, x), Max(z, x)) => Max(Min(y, z), x)
- ExprHandle body =
- Min::make(Max::make(y, x, true), Max::make(z, x, true), true);
- ExprHandle simplified = IRSimplifier::simplify(body);
- checkExprIR(simplified, "Max(Min(y, z, 1), x, 1)");
- }
-
- {
- // Min(Max(y, x), Max(z, x)) => Min(Max(x, y), Max(x, z))
- // When all the ops in the pattern do not have the same propagate_nans,
- // it should not be simplified.
- ExprHandle body =
- Min::make(Max::make(y, x, true), Max::make(z, x, false), true);
- ExprHandle simplified = IRSimplifier::simplify(body);
-
- IS_NODE_WITH_NAME(Min, simplified.node(), min);
- IS_BINOP_W_VARS(Max, min->lhs(), max1, "x", "y");
- ASSERT_TRUE(max1->propagate_nans());
- IS_BINOP_W_VARS(Max, min->rhs(), max2, "x", "z");
- ASSERT_FALSE(max2->propagate_nans());
- ASSERT_TRUE(min->propagate_nans());
- }
-
- {
- // Min(5, Min(x, 8)) => Min(x, 5)
- ExprHandle body = Min::make(5, Min::make(x, 8, true), true);
- ExprHandle simplified = IRSimplifier::simplify(body);
-
- IS_BINOP_W_CONST(Min, simplified.node(), min, "x", 5);
- ASSERT_TRUE(min->propagate_nans());
- }
-
- {
- // Min(8, Min(x, 5)) => Min(x, 5)
- ExprHandle body = Min::make(8, Min::make(x, 5, true), true);
- ExprHandle simplified = IRSimplifier::simplify(body);
-
- IS_BINOP_W_CONST(Min, simplified.node(), min, "x", 5);
- ASSERT_TRUE(min->propagate_nans());
- }
-
- {
- // Min(Min(x, 8), 5) => Min(x, 5)
- ExprHandle body = Min::make(Min::make(x, 8, true), 5, true);
- ExprHandle simplified = IRSimplifier::simplify(body);
-
- IS_BINOP_W_CONST(Min, simplified.node(), min, "x", 5);
- ASSERT_TRUE(min->propagate_nans());
- }
-
- {
- // Min(Min(x, 5), 8) => Min(x, 5)
- ExprHandle body = Min::make(Min::make(x, 5, true), 8, true);
- ExprHandle simplified = IRSimplifier::simplify(body);
-
- IS_BINOP_W_CONST(Min, simplified.node(), min, "x", 5);
- ASSERT_TRUE(min->propagate_nans());
- }
-
- {
- // Min(5, Min(x, Min(y, Min(z, 8)))) => Min(Min(Min(x, 5), y), z)
- ExprHandle body = Min::make(
- 5, Min::make(x, Min::make(y, Min::make(z, 8, true), true), true), true);
- ExprHandle simplified = IRSimplifier::simplify(body);
-
- IS_NODE_WITH_NAME(Min, simplified.node(), min1);
- IS_NODE_WITH_NAME(Min, min1->lhs(), min2);
- IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x", 5);
- ASSERT_TRUE(min3->propagate_nans());
- IS_VAR_WITH_NAME(min2->rhs(), "y");
- IS_VAR_WITH_NAME(min1->rhs(), "z");
- }
-
- {
- // Min(5, Min(Min(y, 
Min(z, 8)), x)) => Min(Min(Min(x, 5), y), z) - ExprHandle body = Min::make( - 5, Min::make(Min::make(y, Min::make(z, 8, true), true), x, true), true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Min, simplified.node(), min1); - IS_NODE_WITH_NAME(Min, min1->lhs(), min2); - IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x", 5); - ASSERT_TRUE(min3->propagate_nans()); - IS_VAR_WITH_NAME(min2->rhs(), "y"); - IS_VAR_WITH_NAME(min1->rhs(), "z"); - } - - { - // Min(5, Min(Min(Min(z, 8), y), x)) => Min(Min(Min(x, 5), y), z) - ExprHandle body = Min::make( - 5, Min::make(Min::make(Min::make(z, 8, true), y, true), x, true), true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Min, simplified.node(), min1); - IS_NODE_WITH_NAME(Min, min1->lhs(), min2); - IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x", 5); - ASSERT_TRUE(min3->propagate_nans()); - IS_VAR_WITH_NAME(min2->rhs(), "y"); - IS_VAR_WITH_NAME(min1->rhs(), "z"); - } - - { - // Min(Min(x, Min(y, Min(8, z))), 5) => Min(Min(Min(x, 5), y), z) - ExprHandle body = Min::make( - Min::make(x, Min::make(y, Min::make(8, z, true), true), true), 5, true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Min, simplified.node(), min1); - IS_NODE_WITH_NAME(Min, min1->lhs(), min2); - IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x", 5); - ASSERT_TRUE(min3->propagate_nans()); - IS_VAR_WITH_NAME(min2->rhs(), "y"); - IS_VAR_WITH_NAME(min1->rhs(), "z"); - } - - { - // Min(Min(Min(y, Min(8, z)), x), 5) => Min(Min(Min(x, 5), y), z) - ExprHandle body = Min::make( - Min::make(Min::make(y, Min::make(z, 8, true), true), x, true), 5, true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Min, simplified.node(), min1); - IS_NODE_WITH_NAME(Min, min1->lhs(), min2); - IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x", 5); - ASSERT_TRUE(min3->propagate_nans()); - IS_VAR_WITH_NAME(min2->rhs(), "y"); - IS_VAR_WITH_NAME(min1->rhs(), "z"); - } - - { - // Min(Min(Min(Min(8, z), y), x), 5) => Min(Min(Min(x, 5), y), z) - ExprHandle body = Min::make( - Min::make(Min::make(Min::make(z, 8, true), y, true), x, true), 5, true); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Min, simplified.node(), min1); - IS_NODE_WITH_NAME(Min, min1->lhs(), min2); - IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x", 5); - ASSERT_TRUE(min3->propagate_nans()); - IS_VAR_WITH_NAME(min2->rhs(), "y"); - IS_VAR_WITH_NAME(min1->rhs(), "z"); - } - - { - // Min(Min(Min(Min(z, 5), y), x), 8) => Min(Min(Min(Min(z, 5), y), x), 8) - // Do not simplify when all the Min ops do not have the same - // propagate_nans. 
- ExprHandle body = Min::make(
- Min::make(Min::make(Min::make(z, 5, true), y, false), x, true),
- 8,
- false);
- ExprHandle simplified = IRSimplifier::simplify(body);
- checkExprIR(simplified, "Min(Min(Min(Min(z, 5, 1), y, 0), x, 1), 8, 0)");
- }
-
- {
- // Min(8, Min(Min(x, 5), Min(y, z))) => Min(Min(Min(x, 5), y), z)
- ExprHandle body = Min::make(
- 8, Min::make(Min::make(x, 5, true), Min::make(y, z, true), true), true);
- ExprHandle simplified = IRSimplifier::simplify(body);
-
- IS_NODE_WITH_NAME(Min, simplified.node(), min1);
- IS_NODE_WITH_NAME(Min, min1->lhs(), min2);
- IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x", 5);
- ASSERT_TRUE(min3->propagate_nans());
- IS_VAR_WITH_NAME(min2->rhs(), "y");
- IS_VAR_WITH_NAME(min1->rhs(), "z");
- }
-
- {
- // Min(Min(Min(x, 5), Min(y, z)), 8) => Min(Min(Min(x, 5), y), z)
- ExprHandle body = Min::make(
- Min::make(Min::make(x, 5, true), Min::make(y, z, true), true), 8, true);
- ExprHandle simplified = IRSimplifier::simplify(body);
-
- IS_NODE_WITH_NAME(Min, simplified.node(), min1);
- IS_NODE_WITH_NAME(Min, min1->lhs(), min2);
- IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x", 5);
- ASSERT_TRUE(min3->propagate_nans());
- IS_VAR_WITH_NAME(min2->rhs(), "y");
- IS_VAR_WITH_NAME(min1->rhs(), "z");
- }
-}
-
-TEST(Simplify, SimplifyWontReorderFloat) {
- {
- // 3 * (3 * x) - 3 * (3 * y) => 9 * (x - y)
- // This is an expression we can simplify.
- VarHandle x("x", kInt);
- VarHandle y("y", kInt);
-
- ExprHandle body = ExprHandle(3) * (ExprHandle(3) * x) -
- ExprHandle(3) * (ExprHandle(3) * y);
- ExprHandle simplified = IRSimplifier::simplify(body);
-
- IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
- IS_IMM_WITH_VAL(Int, mul->lhs(), 9);
- IS_NODE_WITH_NAME(Sub, mul->rhs(), sub);
- IS_VAR_WITH_NAME(sub->lhs(), "x");
- IS_VAR_WITH_NAME(sub->rhs(), "y");
- }
-
- {
- // 3 * (3 * x) - 3 * (3 * y) => 3 * (3 * x) - 3 * (3 * y).
- // If the vars are floating point, ops are not associative and we can't
- // reorder.
- VarHandle x("x", kFloat);
- VarHandle y("y", kFloat);
-
- ExprHandle body = ExprHandle(3) * (ExprHandle(3) * x) -
- ExprHandle(3) * (ExprHandle(3) * y);
- ExprHandle simplified = IRSimplifier::simplify(body);
-
- IS_NODE_WITH_NAME(Sub, simplified.node(), sub);
- IS_NODE_WITH_NAME(Mul, sub->lhs(), lhsMul);
- IS_IMM_WITH_VAL(Float, lhsMul->lhs(), 3);
- IS_NODE_WITH_NAME(Mul, lhsMul->rhs(), lhsVarMul);
- IS_IMM_WITH_VAL(Float, lhsVarMul->lhs(), 3);
- IS_VAR_WITH_NAME(lhsVarMul->rhs(), "x");
-
- IS_NODE_WITH_NAME(Mul, sub->rhs(), rhsMul);
- IS_IMM_WITH_VAL(Float, rhsMul->lhs(), 3);
- IS_NODE_WITH_NAME(Mul, rhsMul->rhs(), rhsVarMul);
- IS_IMM_WITH_VAL(Float, rhsVarMul->lhs(), 3);
- IS_VAR_WITH_NAME(rhsVarMul->rhs(), "y");
- }
-
- {
- // 3 * (3 * x) - 3 * (3 * y) => 3 * (3 * x) - (9 * y).
- // We will simplify subexprs if they don't reorder floating point ops.
- VarHandle x("x", kDouble);
- VarHandle y("y", kInt);
-
- ExprHandle body = ExprHandle(3) * (ExprHandle(3) * x) -
- ExprHandle(3) * (ExprHandle(3) * y);
- ExprHandle simplified = IRSimplifier::simplify(body);
-
- IS_NODE_WITH_NAME(Sub, simplified.node(), sub);
- IS_NODE_WITH_NAME(Mul, sub->lhs(), lhsMul);
- IS_IMM_WITH_VAL(Double, lhsMul->lhs(), 3);
- IS_NODE_WITH_NAME(Mul, lhsMul->rhs(), lhsVarMul);
- IS_IMM_WITH_VAL(Double, lhsVarMul->lhs(), 3);
- IS_VAR_WITH_NAME(lhsVarMul->rhs(), "x");
-
- IS_NODE_WITH_NAME_AND_CAST(Mul, sub->rhs(), rhsMul, Double);
- IS_IMM_WITH_VAL(Int, rhsMul->lhs(), 9);
- IS_VAR_WITH_NAME(rhsMul->rhs(), "y");
- }
-
- {
- // Prevent reordering if FP propagated from dtypes.
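Two floating-point facts drive this whole test: reordering floating-point arithmetic can change the rounded result, and x * 0.f cannot be folded to 0.f because NaN * 0 is NaN. Both are easy to confirm standalone (plain C++):

#include <cassert>
#include <cmath>
#include <limits>

int main() {
  // Reassociation changes results: the 1.0f is absorbed by 1e20f.
  float a = (1e20f + 1.0f) - 1e20f; // rounds to 0.0f
  float b = (1e20f - 1e20f) + 1.0f; // exactly 1.0f
  assert(a != b);

  // x * 0.0f is not always 0.0f:
  float nan = std::numeric_limits<float>::quiet_NaN();
  assert(std::isnan(nan * 0.0f));
  return 0;
}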
- VarHandle x("x", kInt);
- VarHandle y("y", kInt);
-
- ExprHandle body = ExprHandle(3.f) * (ExprHandle(3) * x) -
- ExprHandle(3) * (ExprHandle(3.f) * y);
- ExprHandle simplified = IRSimplifier::simplify(body);
-
- IS_NODE_WITH_NAME(Sub, simplified.node(), sub);
- IS_NODE_WITH_NAME(Mul, sub->lhs(), lhsMul);
- IS_IMM_WITH_VAL(Float, lhsMul->lhs(), 3);
- IS_NODE_WITH_NAME_AND_CAST(Mul, lhsMul->rhs(), lhsVarMul, Float);
- IS_IMM_WITH_VAL(Int, lhsVarMul->lhs(), 3);
- IS_VAR_WITH_NAME(lhsVarMul->rhs(), "x");
-
- IS_NODE_WITH_NAME(Mul, sub->rhs(), rhsMul);
- IS_IMM_WITH_VAL(Float, rhsMul->lhs(), 3);
- IS_NODE_WITH_NAME(Mul, rhsMul->rhs(), rhsVarMul);
- IS_IMM_WITH_VAL(Float, rhsVarMul->lhs(), 3);
- IS_NODE_WITH_NAME(Cast, rhsVarMul->rhs(), yCast);
- IS_VAR_WITH_NAME(yCast->src_value(), "y");
- }
-
- {
- VarHandle x("x", kFloat);
- VarHandle y("y", kFloat);
- // x%y - (x%y - 1) => x%y - (x%y - 1).
- // We won't reorder opaque ops if they are FP.
- ExprHandle body = (x % y) - ((x % y) - 1);
- ExprHandle simplified = IRSimplifier::simplify(body);
-
- IS_NODE_WITH_NAME(Sub, simplified.node(), sub);
- IS_NODE_WITH_NAME(Mod, sub->lhs(), lhsMod);
- IS_VAR_WITH_NAME(lhsMod->lhs(), "x");
- IS_VAR_WITH_NAME(lhsMod->rhs(), "y");
-
- IS_NODE_WITH_NAME(Sub, sub->rhs(), rhsSub);
- IS_NODE_WITH_NAME(Mod, rhsSub->lhs(), rhsMod);
- IS_VAR_WITH_NAME(rhsMod->lhs(), "x");
- IS_VAR_WITH_NAME(rhsMod->rhs(), "y");
- IS_IMM_WITH_VAL(Float, rhsSub->rhs(), 1);
- }
-}
-
-TEST(Simplify, SimplifyRoundModPattern) {
- {
- // (x/y)*y + x%y => x.
- VarHandle x("x", kInt);
- VarHandle y("y", kInt);
- ExprHandle body = ((x / y) * y) + (x % y);
- ExprHandle simplified = IRSimplifier::simplify(body);
- IS_VAR_WITH_NAME(simplified.node(), "x");
- }
-
- {
- // Reverse order.
- // x%y + (x/y)*y => x.
- VarHandle x("x", kInt);
- VarHandle y("y", kInt);
- ExprHandle body = (x % y) + ((x / y) * y);
- ExprHandle simplified = IRSimplifier::simplify(body);
- IS_VAR_WITH_NAME(simplified.node(), "x");
- }
-
- {
- // Non opaque denominator.
- // ((x / (4+y)) * (4+y)) + (x % (y + 4)) => x.
- VarHandle x("x", kInt);
- VarHandle y("y", kInt);
- ExprHandle body = ((x / (ExprHandle(4) + y)) * (ExprHandle(4) + y)) +
- (x % (y + ExprHandle(4)));
- ExprHandle simplified = IRSimplifier::simplify(body);
- IS_VAR_WITH_NAME(simplified.node(), "x");
- }
-
- {
- // Reverse order.
- // (x % (y + 4)) + ((x / (4+y)) * (4+y)) => x.
- VarHandle x("x", kInt);
- VarHandle y("y", kInt);
- ExprHandle body = (x % (y + ExprHandle(4))) +
- ((x / (ExprHandle(4) + y)) * (ExprHandle(4) + y));
- ExprHandle simplified = IRSimplifier::simplify(body);
- IS_VAR_WITH_NAME(simplified.node(), "x");
- }
-
- {
- // Opaque denominator.
- // ((x / (2/y)) * (2/y)) + (x % (2/y)) => x.
- VarHandle x("x", kInt);
- VarHandle y("y", kInt);
- ExprHandle body = ((x / (ExprHandle(2) / y)) * (ExprHandle(2) / y)) +
- (x % (ExprHandle(2) / y));
- ExprHandle simplified = IRSimplifier::simplify(body);
- IS_VAR_WITH_NAME(simplified.node(), "x");
- }
-
- {
- // Non opaque numerator
- // ((2*x)/y * y) + ((2*x) % y) => 2 * x.
- VarHandle x("x", kInt);
- VarHandle y("y", kInt);
- ExprHandle body =
- (((ExprHandle(2) * x) / y) * y) + ((ExprHandle(2) * x) % y);
- ExprHandle simplified = IRSimplifier::simplify(body);
- IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
- IS_IMM_WITH_VAL(Int, mul->lhs(), 2);
- IS_VAR_WITH_NAME(mul->rhs(), "x");
- }
-
- {
- // Opaque numerator.
- // ((x/2) / y * y) + (x/2 % y) => x / 2.
- VarHandle x("x", kInt); - VarHandle y("y", kInt); - ExprHandle body = - (((x / ExprHandle(2)) / y) * y) + ((x / ExprHandle(2)) % y); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Div, simplified.node(), div); - IS_VAR_WITH_NAME(div->lhs(), "x"); - IS_IMM_WITH_VAL(Int, div->rhs(), 2); - } - - { - // Numerator and denominator. - // ((2*x)/(2*y) * (2*y)) + ((2*x) % (2*y)) => 2 * x. - VarHandle x("x", kInt); - VarHandle y("y", kInt); - ExprHandle body = - (((ExprHandle(2) * x) / (ExprHandle(2) * y)) * (ExprHandle(2) * y)) + - ((ExprHandle(2) * x) % (ExprHandle(2) * y)); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), 2); - IS_VAR_WITH_NAME(mul->rhs(), "x"); - } - - { - // Reverse order. - // ((2*x) % (2*y)) + ((2*x)/(2*y) * (2*y)) => 2 * x. - VarHandle x("x", kInt); - VarHandle y("y", kInt); - ExprHandle body = ((ExprHandle(2) * x) % (ExprHandle(2) * y)) + - (((ExprHandle(2) * x) / (ExprHandle(2) * y)) * (ExprHandle(2) * y)); - ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), 2); - IS_VAR_WITH_NAME(mul->rhs(), "x"); - } - - { - // Negated Subtraction of Round Mod. - // (x/y) * y - (0 - x%y) => x. - VarHandle x("x", kInt); - VarHandle y("y", kInt); - ExprHandle body = ((x / y) * y) - (ExprHandle(0) - (x % y)); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_VAR_WITH_NAME(simplified.node(), "x"); - } - - { - // Other terms are preserved. - // (x/y)*y + x%y + (y * x) => x + (y * x). - VarHandle x("x", kInt); - VarHandle y("y", kInt); - ExprHandle body = ((x / y) * y) + (x % y) + (y * x); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Add, simplified.node(), add); - IS_VAR_WITH_NAME(add->lhs(), "x"); - IS_NODE_WITH_NAME(Mul, add->rhs(), mul); - IS_VAR_WITH_NAME(mul->lhs(), "x"); - IS_VAR_WITH_NAME(mul->rhs(), "y"); - } - - { - // Sanity checking we won't do the optimization on floats. - VarHandle x("x", kFloat); - VarHandle y("y", kFloat); - ExprHandle body = ((x / y) * y) + (x % y); - ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Add, simplified.node(), add); - IS_NODE_WITH_NAME(Mul, add->lhs(), roundMul); - IS_NODE_WITH_NAME(Div, roundMul->lhs(), roundDiv); - IS_VAR_WITH_NAME(roundDiv->lhs(), "x"); - IS_VAR_WITH_NAME(roundDiv->rhs(), "y"); - IS_VAR_WITH_NAME(roundMul->rhs(), "y"); - IS_NODE_WITH_NAME(Mod, add->rhs(), mod); - IS_VAR_WITH_NAME(mod->lhs(), "x"); - IS_VAR_WITH_NAME(mod->rhs(), "y"); - } - - { - // Sanity check we won't do it if the mod term doesn't match. - VarHandle x("x", kInt); - VarHandle y("y", kInt); - VarHandle z("z", kInt); - ExprHandle body = ((x / y) * y) + (x % z); - ExprHandle simplified = IRSimplifier::simplify(body); - checkExprIR(simplified, "(x / y) * y + x % z"); - } - - { - // Sanity check we won't do it if the div term doesn't match. - VarHandle x("x", kInt); - VarHandle y("y", kInt); - VarHandle z("z", kInt); - ExprHandle body = (y * (x / z)) + (x % y); - ExprHandle simplified = IRSimplifier::simplify(body); - checkExprIR(simplified, "x % y + (x / z) * y"); - } - - { - // Sanity check we won't do it if the mul term doesn't match. 
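These guards all probe the edges of one fact: for integers with a nonzero divisor (and no overflow), C++ guarantees (a / b) * b + a % b == a, negative operands included, and the rewrite is only sound when the div, mod, and mul pieces line up exactly as in that identity. A standalone check (plain C++):

#include <cassert>

int main() {
  for (int a = -50; a <= 50; ++a) {
    for (int b = -50; b <= 50; ++b) {
      if (b == 0) continue;
      assert((a / b) * b + a % b == a); // guaranteed by the C++ standard
    }
  }
  return 0;
}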
-    VarHandle x("x", kInt);
-    VarHandle y("y", kInt);
-    VarHandle z("z", kInt);
-    ExprHandle body = ((x / y) * z) + (x % y);
-    ExprHandle simplified = IRSimplifier::simplify(body);
-    checkExprIR(simplified, "x % y + (x / y) * z");
-  }
-}
-
-TEST(Simplify, SimplifyRoundModPatternFactorization) {
-  {
-    // Full factorization.
-    // 2 * (x/y * y) + 2 * (x%y) => 2 * x.
-    VarHandle x("x", kInt);
-    VarHandle y("y", kInt);
-    ExprHandle body = ExprHandle(2) * ((x / y) * y) + ExprHandle(2) * (x % y);
-    ExprHandle simplified = IRSimplifier::simplify(body);
-    IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
-    IS_IMM_WITH_VAL(Int, mul->lhs(), 2);
-    IS_VAR_WITH_NAME(mul->rhs(), "x");
-  }
-
-  {
-    // Partial Factorization.
-    // 32 * (x/8) + 4 * (x % 8) => 4 * x.
-    VarHandle x("x", kInt);
-    VarHandle y("y", kInt);
-    // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks,cppcoreguidelines-avoid-magic-numbers)
-    ExprHandle body = ExprHandle(32) * (x / 8) + ExprHandle(4) * (x % 8);
-    ExprHandle simplified = IRSimplifier::simplify(body);
-
-    IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
-    IS_IMM_WITH_VAL(Int, mul->lhs(), 4);
-    IS_VAR_WITH_NAME(mul->rhs(), "x");
-  }
-
-  {
-    // Factorization requiring constant folding.
-    // 20 * (x / (16 / 2)) * 2 + (11 % 6) * (x % (7+1)) => 5 * x.
-    VarHandle x("x", kInt);
-    ExprHandle body = ExprHandle(40) * (x / (ExprHandle(16) / 2)) +
-        (ExprHandle(11) % 6) * (x % (ExprHandle(7) + 1));
-    ExprHandle simplified = IRSimplifier::simplify(body);
-    IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
-    IS_IMM_WITH_VAL(Int, mul->lhs(), 5);
-    IS_VAR_WITH_NAME(mul->rhs(), "x");
-  }
-
-  {
-    VarHandle x("x", kInt);
-    ExprHandle body = (x / 5) * 10 + ExprHandle(2) * (x % 5);
-    ExprHandle simplified = IRSimplifier::simplify(body);
-    IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
-    IS_IMM_WITH_VAL(Int, mul->lhs(), 2);
-    IS_VAR_WITH_NAME(mul->rhs(), "x");
-  }
-
-  {
-    VarHandle x("x", kInt);
-    ExprHandle body = (x / 10) * 0 + x % 5;
-    ExprHandle simplified = IRSimplifier::simplify(body);
-    IS_NODE_WITH_NAME(Mod, simplified.node(), mod);
-    IS_VAR_WITH_NAME(mod->lhs(), "x");
-    IS_IMM_WITH_VAL(Int, mod->rhs(), 5);
-  }
-}
-
-TEST(Simplify, SimplifyRoundModPatternMultivar) {
-  {
-    // Multivar.
-    // (x/8) * 8 + (y/5)*5 + x%8 + y%5 => x + y.
-    VarHandle x("x", kInt);
-    VarHandle y("y", kInt);
-    ExprHandle body = (x / ExprHandle(8) * ExprHandle(8)) +
-        (y / ExprHandle(5) * ExprHandle(5)) + (x % 8) + (y % 5);
-    ExprHandle simplified = IRSimplifier::simplify(body);
-    IS_NODE_WITH_NAME(Add, simplified.node(), add);
-    IS_VAR_WITH_NAME(add->lhs(), "x");
-    IS_VAR_WITH_NAME(add->rhs(), "y");
-  }
-
-  {
-    // Find the right var.
-    // (y/8) * 8 + x%8 + y%8 + z%8 => x%8 + y + z%8
-    VarHandle x("x", kInt);
-    VarHandle y("y", kInt);
-    VarHandle z("z", kInt);
-    ExprHandle body =
-        (y / ExprHandle(8) * ExprHandle(8)) + (x % 8) + (y % 8) + (z % 8);
-    ExprHandle simplified = IRSimplifier::simplify(body);
-    IS_NODE_WITH_NAME(Add, simplified.node(), add);
-    IS_NODE_WITH_NAME(Add, add->lhs(), add2);
-    IS_NODE_WITH_NAME(Mod, add2->lhs(), xMod);
-    IS_VAR_WITH_NAME(xMod->lhs(), "x");
-    IS_IMM_WITH_VAL(Int, xMod->rhs(), 8);
-    IS_VAR_WITH_NAME(add2->rhs(), "y");
-    IS_NODE_WITH_NAME(Mod, add->rhs(), zMod);
-    IS_VAR_WITH_NAME(zMod->lhs(), "z");
-    IS_IMM_WITH_VAL(Int, zMod->rhs(), 8);
-  }
-
-  {
-    // Compound.
-    // (x + (z + 512 * y) % 16) + 16 * ((z + 512 * y) / 16)
-    // => (z + 512 * y) + x
-    VarHandle x("x", kInt);
-    VarHandle y("y", kInt);
-    VarHandle z("z", kInt);
-
-    ExprHandle body = x + (z + y * 512) % 16 + ((z + y * 512) / 16 * 16);
-    ExprHandle simplified = IRSimplifier::simplify(body);
-    checkExprIR(simplified, "x + (z + 512 * y)");
-  }
-}
-
-TEST(Simplify, SimplifyModRoundModPattern) {
-  {
-    // t/7 % 9 * 7 + t % 7 => t%63
-    VarHandle t("t", kInt);
-    ExprHandle body = (t / 7 % 9) * 7 + t % 7;
-    ExprHandle simplified = IRSimplifier::simplify(body);
-    IS_NODE_WITH_NAME(Mod, simplified.node(), mod);
-    IS_VAR_WITH_NAME(mod->lhs(), "t");
-    IS_IMM_WITH_VAL(Int, mod->rhs(), 63);
-  }
-
-  {
-    // 2*t/7 % 9 * 7 + 2*t % 7 => 2*t % 63
-    VarHandle t("t", kInt);
-    ExprHandle body = (ExprHandle(2) * t / 7 % 9) * 7 + ExprHandle(2) * t % 7;
-    ExprHandle simplified = IRSimplifier::simplify(body);
-    IS_NODE_WITH_NAME(Mod, simplified.node(), mod);
-    IS_NODE_WITH_NAME(Mul, mod->lhs(), mul);
-    IS_IMM_WITH_VAL(Int, mul->lhs(), 2);
-    IS_VAR_WITH_NAME(mul->rhs(), "t");
-    IS_IMM_WITH_VAL(Int, mod->rhs(), 63);
-  }
-
-  {
-    // t/x % y * x + t % x => t%(x*y)
-    VarHandle t("t", kInt);
-    VarHandle x("x", kInt);
-    VarHandle y("y", kInt);
-    ExprHandle body = (t / x % y) * x + t % x;
-    ExprHandle simplified = IRSimplifier::simplify(body);
-    IS_NODE_WITH_NAME(Mod, simplified.node(), mod);
-    IS_VAR_WITH_NAME(mod->lhs(), "t");
-    IS_NODE_WITH_NAME(Mul, mod->rhs(), mul);
-    IS_VAR_WITH_NAME(mul->lhs(), "x");
-    IS_VAR_WITH_NAME(mul->rhs(), "y");
-  }
-
-  {
-    // k*t/x % y * x + k*t % x => k*t%(x*y)
-    VarHandle t("t", kInt);
-    VarHandle x("x", kInt);
-    VarHandle y("y", kInt);
-    VarHandle k("k", kInt);
-    ExprHandle body = (k * t / x % y) * x + k * t % x;
-    ExprHandle simplified = IRSimplifier::simplify(body);
-    checkExprIR(simplified, "(k * t) % (x * y)");
-  }
-
-  {
-    // t/k/x % y * x + t/k % x => t/k%(x*y)
-    VarHandle t("t", kInt);
-    VarHandle x("x", kInt);
-    VarHandle y("y", kInt);
-    VarHandle k("k", kInt);
-    ExprHandle body = (t / k / x % y) * x + t / k % x;
-    ExprHandle simplified = IRSimplifier::simplify(body);
-    IS_NODE_WITH_NAME(Mod, simplified.node(), mod);
-    IS_NODE_WITH_NAME(Div, mod->lhs(), div);
-    IS_VAR_WITH_NAME(div->lhs(), "t");
-    IS_VAR_WITH_NAME(div->rhs(), "k");
-    IS_NODE_WITH_NAME(Mul, mod->rhs(), mul);
-    IS_VAR_WITH_NAME(mul->lhs(), "x");
-    IS_VAR_WITH_NAME(mul->rhs(), "y");
-  }
-
-  {
-    // Sanity checking we won't do the optimization on floats.
-    VarHandle x("x", kFloat);
-    VarHandle y("y", kFloat);
-    VarHandle z("z", kFloat);
-    ExprHandle body = ((x / y % z) * y) + (x % y);
-    ExprHandle simplified = IRSimplifier::simplify(body);
-    IS_NODE_WITH_NAME(Add, simplified.node(), add);
-    IS_NODE_WITH_NAME(Mul, add->lhs(), mul);
-    IS_NODE_WITH_NAME(Mod, mul->lhs(), mod);
-    IS_NODE_WITH_NAME(Div, mod->lhs(), div);
-    IS_VAR_WITH_NAME(div->lhs(), "x");
-    IS_VAR_WITH_NAME(div->rhs(), "y");
-    IS_VAR_WITH_NAME(mod->rhs(), "z");
-    IS_VAR_WITH_NAME(mul->rhs(), "y");
-    IS_NODE_WITH_NAME(Mod, add->rhs(), mod2);
-    IS_VAR_WITH_NAME(mod2->lhs(), "x");
-    IS_VAR_WITH_NAME(mod2->rhs(), "y");
-  }
-}
-
-TEST(Simplify, SimplifyModRoundModPatternFactorization) {
-  {
-    // 2 * (t/7 % 9 * 7) + 2 * (t % 7) => 2 * (t % 63)
-    VarHandle t("t", kInt);
-    ExprHandle body =
-        ExprHandle(2) * ((t / 7 % 9) * 7) + ExprHandle(2) * (t % 7);
-    ExprHandle simplified = IRSimplifier::simplify(body);
-    IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
-    IS_IMM_WITH_VAL(Int, mul->lhs(), 2);
-    IS_NODE_WITH_NAME(Mod, mul->rhs(), mod);
-    IS_VAR_WITH_NAME(mod->lhs(), "t");
-    IS_IMM_WITH_VAL(Int, mod->rhs(), 63);
-  }
-
-  {
-    // t/7 % 9 * 14 + 2 * (t % 7) => 2 * (t % 63)
-    VarHandle t("t", kInt);
-    ExprHandle body = (t / 7 % 9) * 14 + ExprHandle(2) * (t % 7);
-    ExprHandle simplified = IRSimplifier::simplify(body);
-    IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
-    IS_IMM_WITH_VAL(Int, mul->lhs(), 2);
-    IS_NODE_WITH_NAME(Mod, mul->rhs(), mod);
-    IS_VAR_WITH_NAME(mod->lhs(), "t");
-    IS_IMM_WITH_VAL(Int, mod->rhs(), 63);
-  }
-
-  {
-    // t/14 % 9 * 7 + t/2 % 7 => t/2 % 63
-    VarHandle t("t", kInt);
-    ExprHandle body = (t / 14 % 9) * 7 + t / 2 % 7;
-    ExprHandle simplified = IRSimplifier::simplify(body);
-    IS_NODE_WITH_NAME(Mod, simplified.node(), mod);
-    IS_NODE_WITH_NAME(Div, mod->lhs(), div);
-    IS_VAR_WITH_NAME(div->lhs(), "t");
-    IS_IMM_WITH_VAL(Int, div->rhs(), 2);
-    IS_IMM_WITH_VAL(Int, mod->rhs(), 63);
-  }
-
-  {
-    // t/(7*3) % 9 * 7*3 + t % (7*3) => t % 189
-    VarHandle t("t", kInt);
-    ExprHandle body = (t / (ExprHandle(7) * ExprHandle(3)) % 9) * 7 * 3 +
-        t % (ExprHandle(7) * ExprHandle(3));
-    ExprHandle simplified = IRSimplifier::simplify(body);
-    IS_NODE_WITH_NAME(Mod, simplified.node(), mod);
-    IS_VAR_WITH_NAME(mod->lhs(), "t");
-    IS_IMM_WITH_VAL(Int, mod->rhs(), 189);
-  }
-
-  {
-    // 2*(t/x % y * x) + 2*(t % x) => 2*(t%(x*y))
-    VarHandle t("t", kInt);
-    VarHandle x("x", kInt);
-    VarHandle y("y", kInt);
-    ExprHandle body =
-        ExprHandle(2) * ((t / x % y) * x) + ExprHandle(2) * (t % x);
-    ExprHandle simplified = IRSimplifier::simplify(body);
-    IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
-    IS_IMM_WITH_VAL(Int, mul->lhs(), 2);
-    IS_NODE_WITH_NAME(Mod, mul->rhs(), mod);
-    IS_VAR_WITH_NAME(mod->lhs(), "t");
-    IS_NODE_WITH_NAME(Mul, mod->rhs(), mul2);
-    IS_VAR_WITH_NAME(mul2->lhs(), "x");
-    IS_VAR_WITH_NAME(mul2->rhs(), "y");
-  }
-}
-
-TEST(Simplify, SimplifyModRoundModPatternMultivar) {
-  {
-    // t/7 % 9 * 7 + t % 7 + t => t % 63 + t
-    VarHandle t("t", kInt);
-    ExprHandle body = (t / 7 % 9) * 7 + t % 7 + t;
-    ExprHandle simplified = IRSimplifier::simplify(body);
-    checkExprIR(simplified, "t % 63 + t");
-  }
-
-  {
-    // t/7 % 9 * 7 + t/8 % 9 * 8 + t % 7 + t % 8 => t % 63 + t % 72
-    VarHandle t("t", kInt);
-    ExprHandle body = (t / 7 % 9) * 7 + (t / 8 % 9) * 8 + t % 7 + t % 8;
-    ExprHandle simplified = IRSimplifier::simplify(body);
-    IS_NODE_WITH_NAME(Add, simplified.node(), add);
-    IS_NODE_WITH_NAME(Mod, add->lhs(), mod1);
-    IS_VAR_WITH_NAME(mod1->lhs(), "t");
-    IS_IMM_WITH_VAL(Int, mod1->rhs(), 63);
-    IS_NODE_WITH_NAME(Mod, add->rhs(), mod2);
-    IS_VAR_WITH_NAME(mod2->lhs(), "t");
-    IS_IMM_WITH_VAL(Int, mod2->rhs(), 72);
-  }
-
-  {
-    // k + t/x % y * x + t % x => k + t%(x*y)
-    VarHandle t("t", kInt);
-    VarHandle x("x", kInt);
-    VarHandle y("y", kInt);
-    VarHandle k("k", kInt);
-    ExprHandle body = k + (t / x % y) * x + t % x;
-    ExprHandle simplified = IRSimplifier::simplify(body);
-    IS_NODE_WITH_NAME(Add, simplified.node(), add);
-    IS_VAR_WITH_NAME(add->lhs(), "k");
-    IS_NODE_WITH_NAME(Mod, add->rhs(), mod);
-    IS_VAR_WITH_NAME(mod->lhs(), "t");
-    IS_NODE_WITH_NAME(Mul, mod->rhs(), mul);
-    IS_VAR_WITH_NAME(mul->lhs(), "x");
-    IS_VAR_WITH_NAME(mul->rhs(), "y");
-  }
-
-  {
-    // t/x % y * x + t % x + (t/k / x % y) * x + t/k % x
-    // => t%(x*y) + t/k % (x*y)
-    VarHandle t("t", kInt);
-    VarHandle x("x", kInt);
-    VarHandle y("y", kInt);
-    VarHandle k("k", kInt);
-    ExprHandle body = (t / x % y) * x + t % x + (t / k / x % y) * x + t / k % x;
-    ExprHandle simplified = IRSimplifier::simplify(body);
-    checkExprIR(simplified, "(t / k) % (x * y) + t % (x * y)");
-  }
-
-  {
-    // 3D: (7 * ((i0_flat / 7) % 9) + i0_flat % 7) + 63 * (i0_flat / 63)
-    // => io_flat
-    VarHandle t("io_flat", kInt);
-    ExprHandle body =
-        ExprHandle(7) * (t / 7 % 9) + t % 7 + ExprHandle(63) * (t / 63);
-    ExprHandle simplified = IRSimplifier::simplify(body);
-    IS_VAR_WITH_NAME(simplified.node(), "io_flat");
-  }
-
-  { // 5D: i0_flat / (11 * 10 * 9 * 7) * (7 * 9 * 10 * 11) +
-    //     (i0_flat / (10 * 9 * 7) % 11) * 7 * 9 * 10 +
-    //     (i0_flat / (9 * 7) % 10) * 7 * 9 +
-    //     (i0_flat / 7 % 9) * 7 +
-    //     i0_flat % 7 => io_flat
-    VarHandle t("io_flat", kInt);
-    ExprHandle body = (t / (ExprHandle(11) * 10 * 9 * 7)) * (7 * 9 * 10 * 11) +
-        (t / (ExprHandle(10) * 9 * 7) % 11) * 7 * 9 * 10 +
-        (t / (ExprHandle(9) * 7) % 10) * 7 * 9 + (t / 7 % 9) * 7 + t % 7;
-    ExprHandle simplified = IRSimplifier::simplify(body);
-    IS_VAR_WITH_NAME(simplified.node(), "io_flat");
-  }
-
-  {
-    // 3D: (m * ((i0_flat / m) % n) + i0_flat % m) + (m * n) *
-    // (i0_flat / (m * n)) => io_flat
-    VarHandle t("io_flat", kInt);
-    VarHandle m("m", kInt);
-    VarHandle n("n", kInt);
-    ExprHandle body = m * (t / m % n) + t % m + (m * n) * (t / (m * n));
-    ExprHandle simplified = IRSimplifier::simplify(body);
-    IS_VAR_WITH_NAME(simplified.node(), "io_flat");
-  }
-
-  { // 5D: i0_flat / (k * l * n * m) * (m * n * l * k) +
-    //     (i0_flat / (l * n * m) % k) * m * n * l +
-    //     (i0_flat / (n * m) % l) * m * n +
-    //     (i0_flat / m % n) * m +
-    //     i0_flat % m => io_flat
-    VarHandle t("io_flat", kInt);
-    VarHandle m("m", kInt);
-    VarHandle n("n", kInt);
-    VarHandle l("l", kInt);
-    VarHandle k("k", kInt);
-    ExprHandle body = (t / (k * l * n * m)) * (m * n * l * k) +
-        (t / (l * n * m) % k) * m * n * l + (t / (n * m) % l) * m * n +
-        (t / m % n) * m + t % m;
-    ExprHandle simplified = IRSimplifier::simplify(body);
-    IS_VAR_WITH_NAME(simplified.node(), "io_flat");
-  }
-}
-
-TEST(Simplify, SimplifyDivisionScalarFactorization) {
-  {
-    // Simple factorization of numerator and denominator.
-    // 8x / 4y => 2x / y.
-    VarHandle x("x", kInt);
-    VarHandle y("y", kInt);
-    ExprHandle body = (x * 8) / (y * 4);
-    ExprHandle simplified = IRSimplifier::simplify(body);
-    IS_NODE_WITH_NAME(Div, simplified.node(), div);
-    IS_NODE_WITH_NAME(Mul, div->lhs(), lhs);
-    IS_IMM_WITH_VAL(Int, lhs->lhs(), 2);
-    IS_VAR_WITH_NAME(lhs->rhs(), "x");
-    IS_VAR_WITH_NAME(div->rhs(), "y");
-  }
-
-  {
-    // Don't change anything if we can't factorize.
-    VarHandle x("x", kInt);
-    VarHandle y("y", kInt);
-    ExprHandle body = (x * 7) / (y * 4);
-    ExprHandle simplified = IRSimplifier::simplify(body);
-    IS_NODE_WITH_NAME(Div, simplified.node(), div);
-    IS_NODE_WITH_NAME(Mul, div->lhs(), lhs);
-    IS_IMM_WITH_VAL(Int, lhs->lhs(), 7);
-    IS_VAR_WITH_NAME(lhs->rhs(), "x");
-    IS_NODE_WITH_NAME(Mul, div->rhs(), rhs);
-    IS_IMM_WITH_VAL(Int, rhs->lhs(), 4);
-    IS_VAR_WITH_NAME(rhs->rhs(), "y");
-  }
-
-  {
-    // Don't reorder floats.
-    VarHandle x("x", kFloat);
-    VarHandle y("y", kFloat);
-    ExprHandle body = (x * 8) / (y * 4);
-    ExprHandle simplified = IRSimplifier::simplify(body);
-    IS_NODE_WITH_NAME(Div, simplified.node(), div);
-    IS_NODE_WITH_NAME(Mul, div->lhs(), lhs);
-    IS_VAR_WITH_NAME(lhs->lhs(), "x");
-    IS_IMM_WITH_VAL(Float, lhs->rhs(), 8.f);
-    IS_NODE_WITH_NAME(Mul, div->rhs(), rhs);
-    IS_VAR_WITH_NAME(rhs->lhs(), "y");
-    IS_IMM_WITH_VAL(Float, rhs->rhs(), 4.f);
-  }
-
-  {
-    // Sanity check we do nothing if there are only scalar parts.
-    VarHandle x("x", kInt);
-    VarHandle y("y", kInt);
-    ExprHandle body = (x * 1) / (y * 1);
-    ExprHandle simplified = IRSimplifier::simplify(body);
-    IS_NODE_WITH_NAME(Div, simplified.node(), div);
-    IS_VAR_WITH_NAME(div->lhs(), "x");
-    IS_VAR_WITH_NAME(div->rhs(), "y");
-  }
-
-  {
-    // Can factorize amounts of variables.
-    VarHandle x("x", kInt);
-    VarHandle y("y", kInt);
-    ExprHandle body = (x + x + x + x) / (y + y);
-    ExprHandle simplified = IRSimplifier::simplify(body);
-    IS_NODE_WITH_NAME(Div, simplified.node(), div);
-    IS_NODE_WITH_NAME(Mul, div->lhs(), lhs);
-    IS_IMM_WITH_VAL(Int, lhs->lhs(), 2);
-    IS_VAR_WITH_NAME(lhs->rhs(), "x");
-    IS_VAR_WITH_NAME(div->rhs(), "y");
-  }
-}
-
-TEST(Simplify, SimplifyConstantBranches) {
-  {
-    // If the condition is constant true then take the true_value.
-    // 1 ? x : y => x
-    VarHandle x("x", kInt);
-    VarHandle y("y", kInt);
-    ExprHandle t(1);
-    ExprHandle body = IfThenElse::make(t, x, y);
-    ExprHandle simplified = IRSimplifier::simplify(body);
-    IS_VAR_WITH_NAME(simplified.node(), "x");
-  }
-
-  {
-    // If the condition is constant false then take the false_value.
-    // 0 ? x : y => y
-    VarHandle x("x", kInt);
-    VarHandle y("y", kInt);
-    ExprHandle t(0);
-    ExprHandle body = IfThenElse::make(t, x, y);
-    ExprHandle simplified = IRSimplifier::simplify(body);
-    IS_VAR_WITH_NAME(simplified.node(), "y");
-  }
-
-  {
-    // condition is simplified before checking.
-    // (x-x) ? x : y => y
-    VarHandle x("x", kInt);
-    VarHandle y("y", kInt);
-    ExprHandle body = IfThenElse::make(x - x, x, y);
-    ExprHandle simplified = IRSimplifier::simplify(body);
-    IS_VAR_WITH_NAME(simplified.node(), "y");
-  }
-
-  {
-    // If both branches are the same then don't do the condition.
-    // y ? x : x => x
-    VarHandle x("x", kInt);
-    VarHandle y("y", kInt);
-    ExprHandle body = IfThenElse::make(y, x, x);
-    ExprHandle simplified = IRSimplifier::simplify(body);
-    IS_VAR_WITH_NAME(simplified.node(), "x");
-  }
-
-  {
-    // If both branches simplify to the same thing it still works.
-    // y ? (x + x) : (2 * x) => x
-    VarHandle x("x", kInt);
-    VarHandle y("y", kInt);
-    ExprHandle body = IfThenElse::make(y, x + x, ExprHandle(2) * x);
-    ExprHandle simplified = IRSimplifier::simplify(body);
-    IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
-    IS_IMM_WITH_VAL(Int, mul->lhs(), 2);
-    IS_VAR_WITH_NAME(mul->rhs(), "x");
-  }
-}
-
-TEST(Simplify, SimplifyConstantCond) {
-  {
-    // If the condition is constant true then take the true_value.
-    // 1 ? A[0] = 1 : B[0] = 1 => A[0] = 1
-    BufHandle a("A", {1}, kInt);
-    BufHandle b("B", {1}, kInt);
-    ExprHandle condition(1);
-    StmtPtr true_val = Store::make(a, {0}, 1);
-    StmtPtr false_val = Store::make(b, {0}, 1);
-
-    CondPtr body = alloc<Cond>(condition.node(), true_val, false_val);
-    StmtPtr simplified = IRSimplifier::simplify(body);
-    BlockPtr block = to<Block>(simplified);
-    IS_NODE_WITH_NAME(Store, block->front(), store);
-    IS_VAR_WITH_NAME(store->base_handle(), "A");
-  }
-
-  {
-    // If the condition is constant false then take the false_value.
-    // 0 ? A[0] = 1 : B[0] = 1 => B[0] = 1
-    BufHandle a("A", {1}, kInt);
-    BufHandle b("B", {1}, kInt);
-    ExprHandle condition(0);
-    StmtPtr true_val = Store::make(a, {0}, 1);
-    StmtPtr false_val = Store::make(b, {0}, 1);
-
-    StmtPtr body = alloc<Cond>(condition.node(), true_val, false_val);
-    StmtPtr simplified = IRSimplifier::simplify(body);
-    BlockPtr block = to<Block>(simplified);
-    IS_NODE_WITH_NAME(Store, block->front(), store);
-    IS_VAR_WITH_NAME(store->base_handle(), "B");
-  }
-
-  {
-    // condition is simplified before checking.
-    // (x-x) ? A[0] = 1 : B[0] = 1 => B[0] = 1
-    VarHandle x("x", kInt);
-    BufHandle a("A", {1}, kInt);
-    BufHandle b("B", {1}, kInt);
-    ExprHandle condition(x - x);
-    StmtPtr true_val = Store::make(a, {0}, 1);
-    StmtPtr false_val = Store::make(b, {0}, 1);
-
-    StmtPtr body = alloc<Cond>(condition.node(), true_val, false_val);
-    StmtPtr simplified = IRSimplifier::simplify(body);
-    BlockPtr block = to<Block>(simplified);
-    IS_NODE_WITH_NAME(Store, block->front(), store);
-    IS_VAR_WITH_NAME(store->base_handle(), "B");
-  }
-
-  {
-    // If both branches are the same then don't do the condition.
-    // x ? A[0] = x : A[0] = x => A[0] = x
-    VarHandle x("x", kInt);
-    BufHandle a("A", {1}, kInt);
-    ExprHandle condition(x - x);
-    StmtPtr true_val = Store::make(a, {0}, x);
-    StmtPtr false_val = Store::make(a, {0}, x);
-
-    StmtPtr body = alloc<Cond>(condition.node(), true_val, false_val);
-    StmtPtr simplified = IRSimplifier::simplify(body);
-    BlockPtr block = to<Block>(simplified);
-    IS_NODE_WITH_NAME(Store, block->front(), store);
-    IS_VAR_WITH_NAME(store->base_handle(), "A");
-  }
-
-  {
-    // If both branches simplify to the same thing it still works.
-    // x ? (x + x) : (2 * x) => x
-    VarHandle x("x", kInt);
-    BufHandle a("A", {1}, kInt);
-    ExprHandle condition(x - x);
-    StmtPtr true_val = Store::make(a, {0}, ExprHandle(2) * x);
-    StmtPtr false_val = Store::make(a, {0}, x + x);
-
-    StmtPtr body = alloc<Cond>(condition.node(), true_val, false_val);
-    StmtPtr simplified = IRSimplifier::simplify(body);
-    BlockPtr block = to<Block>(simplified);
-    IS_NODE_WITH_NAME(Store, block->front(), store);
-    IS_VAR_WITH_NAME(store->base_handle(), "A");
-  }
-
-  {
-    // But not if they don't.
-    // x ? x : (2 * x) => x ? x : (2 * x)
-    VarHandle x("x", kInt);
-    BufHandle a("A", {1}, kInt);
-    ExprHandle condition(x);
-    StmtPtr true_val = Store::make(a, {0}, x);
-    StmtPtr false_val = Store::make(a, {0}, ExprHandle(2) * x);
-
-    StmtPtr body = alloc<Cond>(condition.node(), true_val, false_val);
-    StmtPtr simplified = IRSimplifier::simplify(body);
-    BlockPtr block = to<Block>(simplified);
-    ASSERT_EQ(block, nullptr);
-  }
-
-  {
-    StmtPtr cond = alloc<Cond>(
-        ExprHandle(false).node(),
-        alloc<Block>(std::vector<StmtPtr>({})),
-        nullptr);
-    StmtPtr simplified = IRSimplifier::simplify(cond);
-    ASSERT_EQ(simplified, nullptr);
-  }
-
-  {
-    StmtPtr cond = alloc<Cond>(
-        ExprHandle(true).node(),
-        nullptr,
-        alloc<Block>(std::vector<StmtPtr>({})));
-    StmtPtr simplified = IRSimplifier::simplify(cond);
-    ASSERT_EQ(simplified, nullptr);
-  }
-}
-
-TEST(Simplify, SimplifyEliminateEmptyCond) {
-  // If the branches are empty in different ways, eliminate.
-  {
-    VarHandle x("x", kInt);
-    ExprHandle condition(x);
-    StmtPtr true_val = alloc<Block>(std::vector<StmtPtr>({}));
-
-    StmtPtr body = alloc<Cond>(condition.node(), true_val, nullptr);
-    StmtPtr simplified = IRSimplifier::simplify(body);
-    BlockPtr block = to<Block>(simplified);
-    ASSERT_NE(block, nullptr);
-    ASSERT_EQ(block->nstmts(), 0);
-  }
-
-  {
-    VarHandle x("x", kInt);
-    ExprHandle condition(x);
-    StmtPtr false_val = alloc<Block>(std::vector<StmtPtr>({}));
-
-    StmtPtr body = alloc<Cond>(condition.node(), nullptr, false_val);
-    StmtPtr simplified = IRSimplifier::simplify(body);
-    BlockPtr block = to<Block>(simplified);
-    ASSERT_NE(block, nullptr);
-    ASSERT_EQ(block->nstmts(), 0);
-  }
-}
-
-TEST(Simplify, SimplifyConstantComparisons) {
-  auto ComparisonTest =
-      [](ExprHandle a, ExprHandle b, CompareSelectOperation op, int result) {
-        ExprHandle body = CompareSelect::make(a, b, op);
-        ExprHandle simplified = IRSimplifier::simplify(body);
-        IS_IMM_WITH_VAL(Int, simplified.node(), result);
-      };
-
-  // Equals.
-  ComparisonTest(2, 2, kEQ, 1);
-  ComparisonTest(1, 2, kEQ, 0);
-  ComparisonTest(2, 1, kEQ, 0);
-
-  // Greater than.
-  ComparisonTest(2, 2, kGT, 0);
-  ComparisonTest(1, 2, kGT, 0);
-  ComparisonTest(2, 1, kGT, 1);
-
-  // Greater or Equal.
-  ComparisonTest(2, 2, kGE, 1);
-  ComparisonTest(1, 2, kGE, 0);
-  ComparisonTest(2, 1, kGE, 1);
-
-  // Less Than.
-  ComparisonTest(2, 2, kLT, 0);
-  ComparisonTest(1, 2, kLT, 1);
-  ComparisonTest(2, 1, kLT, 0);
-
-  // Less or Equal.
-  ComparisonTest(2, 2, kLE, 1);
-  ComparisonTest(1, 2, kLE, 1);
-  ComparisonTest(2, 1, kLE, 0);
-
-  // Not equal.
-  ComparisonTest(2, 2, kNE, 0);
-  ComparisonTest(1, 2, kNE, 1);
-  ComparisonTest(2, 1, kNE, 1);
-
-  // With specified results:
-  ExprHandle body = CompareSelect::make(2, 2, 5, 42, kNE);
-  ExprHandle simplified = IRSimplifier::simplify(body);
-  IS_IMM_WITH_VAL(Int, simplified.node(), 42);
-}
-
-TEST(Simplify, SimplifySymbolicComparisons) {
-  VarHandle x("x", kInt);
-  VarHandle y("y", kInt);
-
-  auto TookTrueBranch = [](ExprHandle a) { IS_IMM_WITH_VAL(Int, a.node(), 1); };
-  auto TookFalseBranch = [](ExprHandle a) {
-    IS_IMM_WITH_VAL(Int, a.node(), 0);
-  };
-
-  // EQ
-
-  // x == x => 1
-  ExprHandle body = CompareSelect::make(x, x, kEQ);
-  TookTrueBranch(IRSimplifier::simplify(body));
-
-  // x == x+1 => 0
-  body = CompareSelect::make(x, x + 1, kEQ);
-  TookFalseBranch(IRSimplifier::simplify(body));
-
-  // x == x * 2 cannot simplify since we don't know x is nonzero.
-  body = CompareSelect::make(x, x * 2, kEQ);
-  // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
-  IS_NODE(CompareSelect, IRSimplifier::simplify(body).node());
-
-  // x == x * 1 => 1
-  body = CompareSelect::make(x, x * 1, kEQ);
-  TookTrueBranch(IRSimplifier::simplify(body));
-
-  {
-    // x == y => x == y
-    body = CompareSelect::make(x, y, kEQ);
-    ExprHandle simplified = IRSimplifier::simplify(body);
-    IS_NODE_WITH_NAME(CompareSelect, simplified.node(), cmp);
-    ASSERT_EQ(cmp->compare_select_op(), kEQ);
-    IS_VAR_WITH_NAME(cmp->lhs(), "x");
-    IS_VAR_WITH_NAME(cmp->rhs(), "y");
-  }
-
-  {
-    // x == 5 => x == 5
-    body = CompareSelect::make(x, 5, kEQ);
-    ExprHandle simplified = IRSimplifier::simplify(body);
-    IS_NODE_WITH_NAME(CompareSelect, simplified.node(), cmp);
-    ASSERT_EQ(cmp->compare_select_op(), kEQ);
-    IS_VAR_WITH_NAME(cmp->lhs(), "x");
-    IS_IMM_WITH_VAL(Int, cmp->rhs(), 5);
-  }
-
-  // GT
-
-  // x+1 > x => 1
-  body = CompareSelect::make(x + 1, x, kGT);
-  TookTrueBranch(IRSimplifier::simplify(body));
-
-  // x > x + 1 => 0
-  body = CompareSelect::make(x, x + 1, kGT);
-  TookFalseBranch(IRSimplifier::simplify(body));
-
-  // x > x - 1 => 1
-  body = CompareSelect::make(x, x - 1, kGT);
-  TookTrueBranch(IRSimplifier::simplify(body));
-
-  // x - 1 > x => 0
-  body = CompareSelect::make(x - 1, x, kGT);
-  TookFalseBranch(IRSimplifier::simplify(body));
-
-  // x > x => 0
-  body = CompareSelect::make(x, x, kGT);
-  TookFalseBranch(IRSimplifier::simplify(body));
-
-  // x * 2 > x => x * 2 > x
-  // since we don't know the sign of x.
-  body = CompareSelect::make(x * 2, x, kGT);
-  IS_NODE(CompareSelect, IRSimplifier::simplify(body).node());
-
-  // GE
-
-  // x+1 >= x => 1
-  body = CompareSelect::make(x + 1, x, kGE);
-  TookTrueBranch(IRSimplifier::simplify(body));
-
-  // x >= x + 1 => 0
-  body = CompareSelect::make(x, x + 1, kGE);
-  TookFalseBranch(IRSimplifier::simplify(body));
-
-  // x >= x => 1
-  body = CompareSelect::make(x, x, kGE);
-  TookTrueBranch(IRSimplifier::simplify(body));
-
-  // x * 2 >= x => x * 2 >= x
-  // since we don't know the sign of x.
-  body = CompareSelect::make(x * 2, x, kGE);
-  IS_NODE(CompareSelect, IRSimplifier::simplify(body).node());
-
-  // LT
-
-  // x+1 < x => 0
-  body = CompareSelect::make(x + 1, x, kLT);
-  TookFalseBranch(IRSimplifier::simplify(body));
-
-  // x < x + 1 => 1
-  body = CompareSelect::make(x, x + 1, kLT);
-  TookTrueBranch(IRSimplifier::simplify(body));
-
-  // x < x => 0
-  body = CompareSelect::make(x, x, kLT);
-  TookFalseBranch(IRSimplifier::simplify(body));
-
-  // LE
-
-  // x+1 <= x => 0
-  body = CompareSelect::make(x + 1, x, kLE);
-  TookFalseBranch(IRSimplifier::simplify(body));
-
-  // x <= x + 1 => 1
-  body = CompareSelect::make(x, x + 1, kLE);
-  TookTrueBranch(IRSimplifier::simplify(body));
-
-  // x <= x => 1
-  body = CompareSelect::make(x, x, kLE);
-  TookTrueBranch(IRSimplifier::simplify(body));
-
-  // NE
-
-  // x+1 != x => 1
-  body = CompareSelect::make(x + 1, x, kNE);
-  TookTrueBranch(IRSimplifier::simplify(body));
-
-  // x != x + 1 => 1
-  body = CompareSelect::make(x, x + 1, kNE);
-  TookTrueBranch(IRSimplifier::simplify(body));
-
-  // x != x => 0
-  body = CompareSelect::make(x, x, kNE);
-  TookFalseBranch(IRSimplifier::simplify(body));
-}
-
-TEST(Simplify, SimplifyEliminateZeroLengthFor) {
-  {
-    // Will eliminate zero loop For.
-    BufHandle a("A", {4}, kInt);
-    BufHandle c("C", {4}, kInt);
-    VarHandle i("i", kInt);
-    auto body = For::make(i, 0, 0, Store::make(c, {i}, Load::make(a, {i})));
-    StmtPtr simplified = IRSimplifier::simplify(body);
-    BlockPtr block = to<Block>(simplified);
-    ASSERT_EQ(block->nstmts(), 0);
-  }
-
-  {
-    // still works if start is not zero.
-    BufHandle a("A", {4}, kInt);
-    BufHandle c("C", {4}, kInt);
-    VarHandle i("i", kInt);
-    auto body = For::make(i, 2, 2, Store::make(c, {i}, Load::make(a, {i})));
-    StmtPtr simplified = IRSimplifier::simplify(body);
-    BlockPtr block = to<Block>(simplified);
-    ASSERT_EQ(block->nstmts(), 0);
-  }
-
-  {
-    // works if both terms are variable.
-    VarHandle x("x", kInt);
-    BufHandle a("A", {4}, kInt);
-    BufHandle c("C", {4}, kInt);
-    VarHandle i("i", kInt);
-    auto body = For::make(i, x, x, Store::make(c, {i}, Load::make(a, {i})));
-    StmtPtr simplified = IRSimplifier::simplify(body);
-    BlockPtr block = to<Block>(simplified);
-    ASSERT_EQ(block->nstmts(), 0);
-  }
-
-  {
-    // works if one term simplifies down.
-    VarHandle x("x", kInt);
-    BufHandle a("A", {4}, kInt);
-    BufHandle c("C", {4}, kInt);
-    VarHandle i("i", kInt);
-    auto body = For::make(i, 0, x - x, Store::make(c, {i}, Load::make(a, {i})));
-    StmtPtr simplified = IRSimplifier::simplify(body);
-    BlockPtr block = to<Block>(simplified);
-    ASSERT_EQ(block->nstmts(), 0);
-  }
-
-  {
-    // Sanity check does nothing if the condition is not met.
-    BufHandle a("A", {4}, kInt);
-    BufHandle c("C", {4}, kInt);
-    VarHandle i("i", kInt);
-    auto body = For::make(i, 0, 3, Store::make(c, {i}, Load::make(a, {i})));
-    StmtPtr simplified = IRSimplifier::simplify(body);
-    IS_NODE(For, simplified);
-  }
-}
-
-TEST(Simplify, SimplifyOneLoopFor) {
-  {
-    // Will remove the loop if the body is run once.
-    BufHandle a("A", {4}, kInt);
-    BufHandle c("C", {4}, kInt);
-    VarHandle i("i", kInt);
-    auto body = For::make(i, 0, 1, Store::make(c, {i}, Load::make(a, {i})));
-    StmtPtr simplified = IRSimplifier::simplify(body);
-    BlockPtr block = to<Block>(simplified);
-    IS_NODE_WITH_NAME(Store, block->front(), store);
-    IS_VAR_WITH_NAME(store->base_handle(), "C");
-    IS_IMM_WITH_VAL(Int, store->flat_index(), 0);
-  }
-
-  {
-    // still works if start is not zero.
-    BufHandle a("A", {4}, kInt);
-    BufHandle c("C", {4}, kInt);
-    VarHandle i("i", kInt);
-    auto body = For::make(i, 2, 3, Store::make(c, {i}, Load::make(a, {i})));
-    StmtPtr simplified = IRSimplifier::simplify(body);
-    BlockPtr block = to<Block>(simplified);
-    IS_NODE_WITH_NAME(Store, block->front(), store);
-    IS_VAR_WITH_NAME(store->base_handle(), "C");
-    IS_IMM_WITH_VAL(Int, store->flat_index(), 2);
-  }
-
-  {
-    // works if both terms are variable.
-    VarHandle x("x", kInt);
-    BufHandle a("A", {4}, kInt);
-    BufHandle c("C", {4}, kInt);
-    VarHandle i("i", kInt);
-    auto body = For::make(i, x, x + 1, Store::make(c, {i}, Load::make(a, {i})));
-    StmtPtr simplified = IRSimplifier::simplify(body);
-    BlockPtr block = to<Block>(simplified);
-    IS_NODE_WITH_NAME(Store, block->front(), store);
-    IS_VAR_WITH_NAME(store->base_handle(), "C");
-    IS_VAR_WITH_NAME(store->flat_index(), "x");
-  }
-
-  {
-    // works if one term simplifies down.
-    VarHandle x("x", kInt);
-    BufHandle a("A", {4}, kInt);
-    BufHandle c("C", {4}, kInt);
-    VarHandle i("i", kInt);
-    auto body =
-        For::make(i, 0, x - x + 1, Store::make(c, {i}, Load::make(a, {i})));
-    StmtPtr simplified = IRSimplifier::simplify(body);
-    BlockPtr block = to<Block>(simplified);
-    IS_NODE_WITH_NAME(Store, block->front(), store);
-    IS_VAR_WITH_NAME(store->base_handle(), "C");
-    IS_IMM_WITH_VAL(Int, store->flat_index(), 0);
-  }
-
-  {
-    // Sanity check does nothing if the condition is not met.
-    BufHandle a("A", {4}, kInt);
-    BufHandle c("C", {4}, kInt);
-    VarHandle i("i", kInt);
-    auto body = For::make(i, 0, 3, Store::make(c, {i}, Load::make(a, {i})));
-    StmtPtr simplified = IRSimplifier::simplify(body);
-    IS_NODE(For, simplified);
-  }
-}
-
-TEST(Simplify, SimplifyForWontLoseLoopOptions) {
-  {
-    // Sanity check does nothing if the condition is not met.
-    BufHandle a("A", {4}, kInt);
-    BufHandle c("C", {4}, kInt);
-    VarHandle i("i", kInt);
-    LoopOptions options;
-    options.set_gpu_block_index(LoopOptions::IDX_W);
-    auto body =
-        For::make(i, 0, 1, Store::make(c, {i}, Load::make(a, {i})), options);
-    StmtPtr simplified = IRSimplifier::simplify(body);
-    IS_NODE_WITH_NAME(For, simplified, for_);
-    LoopOptions options2 = for_->loop_options();
-    ASSERT_EQ(options.gpu_block_index(), options2.gpu_block_index());
-  }
-}
-
-TEST(Simplify, SimplifyMultilevelFor) {
-  {
-    // Multiple layers of For will be simplified out.
-    BufHandle a("A", {4}, kInt);
-    BufHandle c("C", {4}, kInt);
-    VarHandle i("i", kInt);
-    VarHandle j("j", kInt);
-    auto body = For::make(i, 0, 1, Store::make(c, {i}, Load::make(a, {i})));
-    auto outer = For::make(j, 0, 1, body);
-    StmtPtr simplified = IRSimplifier::simplify(outer);
-    BlockPtr block = to<Block>(simplified);
-    IS_NODE_WITH_NAME(Store, block->front(), store);
-    IS_VAR_WITH_NAME(store->base_handle(), "C");
-    IS_IMM_WITH_VAL(Int, store->flat_index(), 0);
-  }
-
-  {
-    // Will maintain an outer loop if the inner loop is eliminated.
-    BufHandle a("A", {4}, kInt);
-    BufHandle c("C", {4}, kInt);
-    VarHandle i("i", kInt);
-    VarHandle j("j", kInt);
-    auto body = For::make(i, 0, 1, Store::make(c, {i}, Load::make(a, {i})));
-    auto outer = For::make(j, 0, 2, body);
-    StmtPtr simplified = IRSimplifier::simplify(outer);
-    ForPtr for__ = static_to<For>(simplified);
-    IS_NODE_WITH_NAME(For, for__, for_);
-    IS_VAR_WITH_NAME(for_->var(), "j");
-    IS_IMM_WITH_VAL(Int, for_->start(), 0);
-    IS_IMM_WITH_VAL(Int, for_->stop(), 2);
-    BlockPtr block = to<Block>(for_->body());
-    ASSERT_NE(block, nullptr);
-    IS_NODE_WITH_NAME(Store, block->front(), store);
-    IS_VAR_WITH_NAME(store->base_handle(), "C");
-    IS_IMM_WITH_VAL(Int, store->flat_index(), 0);
-  }
-
-  {
-    // Will maintain inner loop if outer loops is eliminated.
-    BufHandle a("A", {4}, kInt);
-    BufHandle c("C", {4}, kInt);
-    VarHandle i("i", kInt);
-    VarHandle j("j", kInt);
-    auto body = For::make(i, 0, 2, Store::make(c, {i}, Load::make(a, {i})));
-    auto outer = For::make(j, 0, 1, body);
-    StmtPtr simplified = IRSimplifier::simplify(outer);
-    BlockPtr block = to<Block>(simplified);
-    IS_NODE_WITH_NAME(For, block->front(), for_);
-    IS_VAR_WITH_NAME(for_->var(), "i");
-    IS_IMM_WITH_VAL(Int, for_->start(), 0);
-    IS_IMM_WITH_VAL(Int, for_->stop(), 2);
-    IS_NODE_WITH_NAME(Store, for_->body()->front(), store);
-    IS_VAR_WITH_NAME(store->base_handle(), "C");
-    IS_VAR_WITH_NAME(store->flat_index(), "i");
-  }
-}
-
-TEST(Simplify, SimplifyForCleansUp) {
-  {
-    BufHandle a("a", {1, 12, 1}, kFloat);
-    VarHandle x("x", kInt);
-    Tensor b = Compute(
-        "x",
-        {1, 12, 1},
-        [](const VarHandle& i, const VarHandle& m, const VarHandle& n) {
-          return i + m + n;
-        });
-    LoopNest l({b});
-    l.prepareForCodegen();
-
-    StmtPtr body = LoopNest::sanitizeNames(l.root_stmt());
-    StmtPtr simplified = IRSimplifier::simplify(body);
-
-    BlockPtr block = to<Block>(simplified);
-    IS_NODE_WITH_NAME(For, block->front(), for_);
-    // for is over "m".
-    IS_VAR_WITH_NAME(for_->var(), "j");
-    // x[m] = m;
-    IS_NODE_WITH_NAME(Store, for_->body()->front(), store);
-    IS_VAR_WITH_NAME(store->flat_index(), "j");
-    IS_VAR_WITH_NAME(store->value(), "j");
-  }
-}
-
-TEST(Simplify, SimplifyEliminateEmptyFor) {
-  {
-    // Flatten many layers around an empty block to an empty block.
-    StmtPtr last = alloc<Block>(std::vector<StmtPtr>({}));
-    for ([[maybe_unused]] const auto i : c10::irange(11)) {
-      VarHandle loopVar("loopVar", kInt);
-      last = For::make(loopVar, 0, 10, last);
-    }
-
-    StmtPtr simplified = IRSimplifier::simplify(last);
-    IS_NODE_WITH_NAME(Block, simplified, block);
-    ASSERT_EQ(block->nstmts(), 0);
-  }
-}
-
-TEST(Simplify, SimplifyFlattenBlock) {
-  {
-    // Flatten multiple blocks down to one.
-    // { { { stmt1, stmt2 } } } => { stmt1, stmt2 }
-    BufHandle a("A", {1}, kInt);
-    StorePtr store1 = Store::make(a, {0}, 1);
-    StorePtr store2 = Store::make(a, {0}, 0);
-
-    BlockPtr block1 = alloc<Block>(std::vector<StmtPtr>({store1, store2}));
-    BlockPtr block2 = alloc<Block>(std::vector<StmtPtr>({block1}));
-
-    BlockPtr enclosing = alloc<Block>(std::vector<StmtPtr>({block2}));
-    StmtPtr simplified = IRSimplifier::simplify(enclosing);
-
-    IS_NODE_WITH_NAME(Block, simplified, block);
-    ASSERT_EQ(block->nstmts(), 2);
-
-    IS_NODE_WITH_NAME(Store, block->front(), store1_);
-    IS_NODE_WITH_NAME(Store, block->back(), store2_);
-
-    ASSERT_EQ(store1->value(), store1_->value());
-    ASSERT_EQ(store2->value(), store2_->value());
-  }
-
-  {
-    // Flatten multiple sub blocks containing statements.
-    // { { stmt1 }, { stmt2 } } => { stmt1, stmt2 }
-    BufHandle a("A", {1}, kInt);
-    StorePtr store1 = Store::make(a, {0}, 1);
-    StorePtr store2 = Store::make(a, {0}, 0);
-
-    BlockPtr block1 = alloc<Block>(std::vector<StmtPtr>({store1}));
-    BlockPtr block2 = alloc<Block>(std::vector<StmtPtr>({store2}));
-
-    BlockPtr enclosing = alloc<Block>(std::vector<StmtPtr>({block1, block2}));
-    StmtPtr simplified = IRSimplifier::simplify(enclosing);
-
-    IS_NODE_WITH_NAME(Block, simplified, block);
-    ASSERT_EQ(block->nstmts(), 2);
-
-    IS_NODE_WITH_NAME(Store, block->front(), store1_);
-    IS_NODE_WITH_NAME(Store, block->back(), store2_);
-
-    ASSERT_EQ(store1->value(), store1_->value());
-    ASSERT_EQ(store2->value(), store2_->value());
-  }
-
-  {
-    // Flatten sub blocks with different depths.
-    // { stmt1 , { { stmt2 } } } => { stmt1, stmt2 }
-    BufHandle a("A", {1}, kInt);
-    StorePtr store1 = Store::make(a, {0}, 1);
-    StorePtr store2 = Store::make(a, {0}, 0);
-
-    BlockPtr block1 = alloc<Block>(std::vector<StmtPtr>({store2}));
-    BlockPtr block2 = alloc<Block>(std::vector<StmtPtr>({block1}));
-
-    BlockPtr enclosing = alloc<Block>(std::vector<StmtPtr>({store1, block2}));
-    StmtPtr simplified = IRSimplifier::simplify(enclosing);
-
-    IS_NODE_WITH_NAME(Block, simplified, block);
-    ASSERT_EQ(block->nstmts(), 2);
-
-    IS_NODE_WITH_NAME(Store, block->front(), store1_);
-    IS_NODE_WITH_NAME(Store, block->back(), store2_);
-
-    ASSERT_EQ(store1->value(), store1_->value());
-    ASSERT_EQ(store2->value(), store2_->value());
-  }
-
-  {
-    // Flatten many layers around an empty block to an empty block.
-    StmtPtr last = alloc<Block>(std::vector<StmtPtr>({}));
-    for ([[maybe_unused]] const auto i : c10::irange(11)) {
-      last = alloc<Block>(std::vector<StmtPtr>({last}));
-    }
-
-    StmtPtr simplified = IRSimplifier::simplify(last);
-    IS_NODE_WITH_NAME(Block, simplified, block);
-    ASSERT_EQ(block->nstmts(), 0);
-  }
-}
-
-TEST(Simplify, SimplifyEliminateZeroLengthAlloc) {
-  {
-    // Simple positive case.
-    BufHandle b("x", {0}, kInt);
-
-    AllocatePtr alloc_ = Allocate::make(b);
-    FreePtr free_ = Free::make(b);
-
-    BlockPtr block1 = alloc<Block>(std::vector<StmtPtr>({alloc_, free_}));
-    ASSERT_EQ(block1->nstmts(), 2);
-
-    StmtPtr simplified = IRSimplifier::simplify(block1);
-    IS_NODE_WITH_NAME(Block, simplified, block2);
-    ASSERT_EQ(block2->nstmts(), 0);
-  }
-
-  {
-    // Simple negative case.
-    BufHandle b("x", {2}, kInt);
-
-    AllocatePtr alloc_ = Allocate::make(b);
-    FreePtr free_ = Free::make(b);
-
-    BlockPtr block1 = alloc<Block>(std::vector<StmtPtr>({alloc_, free_}));
-    ASSERT_EQ(block1->nstmts(), 2);
-
-    StmtPtr simplified = IRSimplifier::simplify(block1);
-    IS_NODE_WITH_NAME(Block, simplified, block2);
-    ASSERT_EQ(block2->nstmts(), 2);
-  }
-
-  {
-    // Finds right Alloc/Free.
-    BufHandle b1("x", {0}, kInt);
-    BufHandle b2("y", {2}, kInt);
-
-    AllocatePtr alloc1 = Allocate::make(b1);
-    AllocatePtr alloc2 = Allocate::make(b2);
-    FreePtr free2_ = Free::make(b2);
-    FreePtr free1_ = Free::make(b1);
-
-    BlockPtr block1 =
-        alloc<Block>(std::vector<StmtPtr>({alloc1, alloc2, free2_, free1_}));
-    ASSERT_EQ(block1->nstmts(), 4);
-
-    StmtPtr simplified = IRSimplifier::simplify(block1);
-    IS_NODE_WITH_NAME(Block, simplified, block2);
-    ASSERT_EQ(block2->nstmts(), 2);
-    IS_NODE_WITH_NAME(Allocate, block2->stmts().front(), simplified_alloc);
-    IS_VAR_WITH_NAME(simplified_alloc->buffer_var(), "y");
-    IS_NODE_WITH_NAME(Free, block2->stmts().back(), simplified_free);
-    ASSERT_EQ(simplified_alloc->buffer_var(), simplified_free->buffer_var());
-  }
-
-  {
-    // Dynamic shape.
-    VarHandle z("z", kInt);
-    BufHandle b1("x", {0}, kInt);
-    BufHandle b2("y", {z}, kInt);
-
-    AllocatePtr alloc1 = Allocate::make(b1);
-    AllocatePtr alloc2 = Allocate::make(b2);
-    FreePtr free2_ = Free::make(b2);
-    FreePtr free1_ = Free::make(b1);
-
-    BlockPtr block1 =
-        alloc<Block>(std::vector<StmtPtr>({alloc1, alloc2, free2_, free1_}));
-    ASSERT_EQ(block1->nstmts(), 4);
-    StmtPtr simplified = IRSimplifier::simplify(block1);
-    IS_NODE_WITH_NAME(Block, simplified, block2);
-    ASSERT_EQ(block2->nstmts(), 2);
-  }
-}
-
-TEST(Simplify, DontSimplifyRand) {
-  {
-    // rand() + rand() = rand() + rand() NOT 2 * rand().
-    ExprHandle body =
-        Intrinsics::make(kRand, kInt) + Intrinsics::make(kRand, kInt);
-    ExprHandle simplified = IRSimplifier::simplify(body);
-    IS_NODE_WITH_NAME(Add, simplified.node(), add);
-    IS_RAND(add->lhs());
-    IS_RAND(add->rhs());
-  }
-
-  {
-    // rand() - rand() = rand() - rand() NOT 0.
-    ExprHandle body =
-        Intrinsics::make(kRand, kFloat) - Intrinsics::make(kRand, kFloat);
-    ExprHandle simplified = IRSimplifier::simplify(body);
-    IS_NODE_WITH_NAME(Sub, simplified.node(), sub);
-    IS_RAND(sub->lhs());
-    IS_RAND(sub->rhs());
-  }
-
-  {
-    // rand() * rand() = rand() * rand().
-    ExprHandle body =
-        Intrinsics::make(kRand, kInt) * Intrinsics::make(kRand, kInt);
-    ExprHandle simplified = IRSimplifier::simplify(body);
-    IS_NODE_WITH_NAME(Mul, simplified.node(), mul);
-    IS_RAND(mul->lhs());
-    IS_RAND(mul->rhs());
-  }
-}
-
-TEST(Simplify, SimplifyReorderForCond) {
-  BufHandle a("A", {4}, kInt);
-  BufHandle b("B", {1}, kInt);
-  BufHandle c("C", {4}, kInt);
-  VarHandle i("i", kInt);
-  VarHandle j("j", kInt);
-
-  {
-    // for ( if ( ... ) ) => if ( for ( ... ) ).
-    auto body = For::make(
-        i,
-        0,
-        4,
-        Cond::make(
-            CompareSelect::make(j, 10, CompareSelectOperation::kLT),
-            Store::make(c, {i}, Load::make(a, {i})),
-            nullptr));
-
-    StmtPtr simplified = IRSimplifier::simplify(body);
-    IS_NODE_WITH_NAME(Cond, simplified, cond);
-    IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_block);
-    IS_NODE_WITH_NAME(For, true_block->front(), loop);
-  }
-
-  {
-    // Can't reorder if condition is dependent on the loop var.
-    auto body = For::make(
-        i,
-        0,
-        4,
-        Cond::make(
-            CompareSelect::make(i, 2, CompareSelectOperation::kEQ),
-            Store::make(c, {i}, Load::make(a, {i})),
-            nullptr));
-
-    StmtPtr simplified = IRSimplifier::simplify(body);
-    IS_NODE_WITH_NAME(For, simplified, loop);
-    IS_NODE_WITH_NAME(Cond, loop->body()->front(), cond);
-  }
-
-  {
-    // Can't reorder if condition is dependent on a var that is modified inside
-    // the loop.
-    auto body = For::make(
-        i,
-        0,
-        4,
-        Cond::make(
-            CompareSelect::make(
-                Load::make(c, {0}), 10, CompareSelectOperation::kLT),
-            Store::make(c, {0}, Load::make(a, {i})),
-            nullptr));
-
-    StmtPtr simplified = IRSimplifier::simplify(body);
-    IS_NODE_WITH_NAME(For, simplified, loop);
-    IS_NODE_WITH_NAME(Cond, loop->body()->front(), cond);
-  }
-
-  {
-    // Condition based on buffer not referenced in body. Can reorder here.
-    auto body = For::make(
-        i,
-        0,
-        4,
-        Cond::make(
-            CompareSelect::make(
-                Load::make(b, {0}), 10, CompareSelectOperation::kLT),
-            Store::make(c, {0}, Load::make(a, {i})),
-            nullptr));
-
-    StmtPtr simplified = IRSimplifier::simplify(body);
-    IS_NODE_WITH_NAME(Cond, simplified, cond);
-    IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_block);
-    IS_NODE_WITH_NAME(For, true_block->front(), loop);
-  }
-
-  {
-    // Condition based on buffer read only in body. Can reorder here.
-    auto body = For::make(
-        i,
-        0,
-        4,
-        Cond::make(
-            CompareSelect::make(
-                Load::make(a, {0}), 10, CompareSelectOperation::kLT),
-            Store::make(c, {0}, Load::make(a, {i})),
-            nullptr));
-
-    StmtPtr simplified = IRSimplifier::simplify(body);
-    IS_NODE_WITH_NAME(Cond, simplified, cond);
-    IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_block);
-    IS_NODE_WITH_NAME(For, true_block->front(), loop);
-  }
-
-  {
-    // Condition depends on Let in the loop. Cannot reorder.
-    auto body = For::make(
-        i,
-        0,
-        4,
-        Block::make(
-            {Let::make(j, 3),
-             Cond::make(
-                 CompareSelect::make(j, 10, CompareSelectOperation::kLT),
-                 Store::make(c, {0}, Load::make(a, {i})),
-                 nullptr)}));
-
-    StmtPtr simplified = IRSimplifier::simplify(body);
-    IS_NODE_WITH_NAME(For, simplified, loop);
-    IS_NODE_WITH_NAME(Let, loop->body()->front(), let);
-    IS_NODE_WITH_NAME(Cond, loop->body()->back(), cond);
-  }
-
-  {
-    // Multi level Ifs where all conditions are distinct. Move BOTH Cond
-    // statements outside the loop.
-    auto body = For::make(
-        i,
-        0,
-        4,
-        Cond::make(
-            CompareSelect::make(
-                Load::make(a, {0}), 10, CompareSelectOperation::kLT),
-            Cond::make(
-                CompareSelect::make(j, 10, CompareSelectOperation::kEQ),
-                Store::make(c, {0}, Load::make(a, {i})),
-                nullptr),
-            nullptr));
-
-    StmtPtr simplified = IRSimplifier::simplify(body);
-    IS_NODE_WITH_NAME(Cond, simplified, cond);
-    IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_block);
-    IS_NODE_WITH_NAME(Cond, true_block->front(), cond2);
-    IS_NODE_WITH_NAME(Block, cond2->true_stmt(), true_block2);
-    IS_NODE_WITH_NAME(For, true_block2->front(), loop);
-  }
-
-  {
-    // Multi level Ifs where the inner condition does depend on a loop var,
-    // reorder only the first Cond.
-    auto body = For::make(
-        i,
-        0,
-        4,
-        Cond::make(
-            CompareSelect::make(
-                Load::make(a, {0}), 10, CompareSelectOperation::kLT),
-            Cond::make(
-                CompareSelect::make(i, 3, CompareSelectOperation::kEQ),
-                Store::make(c, {0}, Load::make(a, {i})),
-                nullptr),
-            nullptr));
-
-    StmtPtr simplified = IRSimplifier::simplify(body);
-    IS_NODE_WITH_NAME(Cond, simplified, cond);
-    IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_block);
-    IS_NODE_WITH_NAME(For, true_block->front(), loop);
-    IS_NODE_WITH_NAME(Block, loop->body(), loop_body);
-    IS_NODE_WITH_NAME(Cond, loop_body->front(), cond2);
-  }
-
-  {
-    // Don't reorder if there's an else block of the Cond.
-    // We could, but is it much better?
-    auto body = For::make(
-        i,
-        0,
-        4,
-        Cond::make(
-            CompareSelect::make(j, 10, CompareSelectOperation::kLT),
-            Store::make(c, {0}, Load::make(a, {i})),
-            Store::make(c, {0}, 0)));
-
-    StmtPtr simplified = IRSimplifier::simplify(body);
-    IS_NODE_WITH_NAME(For, simplified, loop);
-    IS_NODE_WITH_NAME(Cond, loop->body()->front(), cond);
-  }
-
-  {
-    // Condition uses distinct region of Tensor.
-    // We could reorder here with better analysis, but we don't. Included for
-    // completeness.
-    auto body = For::make(
-        i,
-        0,
-        4,
-        Cond::make(
-            CompareSelect::make(
-                Load::make(c, {0}), 10, CompareSelectOperation::kLT),
-            Store::make(c, {1}, Load::make(a, {i})),
-            nullptr));
-
-    StmtPtr simplified = IRSimplifier::simplify(body);
-    IS_NODE_WITH_NAME(For, simplified, loop);
-    IS_NODE_WITH_NAME(Cond, loop->body()->front(), cond);
-  }
-}
-
-TEST(Simplify, SimplifyFuseConditions) {
-  BufHandle a("A", {2}, kInt);
-  BufHandle b("B", {2}, kInt);
-  VarHandle i("i", kInt);
-  VarHandle j("j", kInt);
-
-  {
-    // Can fuse since the conditions are identical.
-    // if (A) { X }; if (A) { Y }; => if (A) { X; Y }
-    auto body = Block::make(
-        {Cond::make(
-             CompareSelect::make(i, 10, CompareSelectOperation::kLT),
-             Store::make(a, {0}, i),
-             nullptr),
-         Cond::make(
-             CompareSelect::make(i, 10, CompareSelectOperation::kLT),
-             Store::make(a, {1}, i),
-             nullptr)});
-
-    StmtPtr simplified = IRSimplifier::simplify(body);
-    // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
-    IS_NODE_WITH_NAME(Block, simplified, block);
-    ASSERT_EQ(block->nstmts(), 1);
-    IS_NODE_WITH_NAME(Cond, block->front(), cond);
-    IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt);
-    ASSERT_EQ(true_stmt->nstmts(), 2);
-    ASSERT_EQ(cond->false_stmt(), nullptr);
-  }
-
-  {
-    // Can't fuse, conditions are not identical in lhs (i != j).
-    auto body = Block::make(
-        {Cond::make(
-             CompareSelect::make(i, 10, CompareSelectOperation::kLT),
-             Store::make(a, {0}, i),
-             nullptr),
-         Cond::make(
-             CompareSelect::make(j, 10, CompareSelectOperation::kLT),
-             Store::make(a, {1}, i),
-             nullptr)});
-
-    StmtPtr simplified = IRSimplifier::simplify(body);
-    IS_NODE_WITH_NAME(Block, simplified, block);
-    ASSERT_EQ(block->nstmts(), 2);
-    IS_NODE_WITH_NAME(Cond, block->front(), cond1);
-    IS_NODE_WITH_NAME(Cond, block->back(), cond2);
-
-    IS_NODE_WITH_NAME(Block, cond1->true_stmt(), true_stmt1);
-    IS_NODE_WITH_NAME(Block, cond2->true_stmt(), true_stmt2);
-    ASSERT_EQ(true_stmt1->nstmts(), 1);
-    ASSERT_EQ(true_stmt2->nstmts(), 1);
-
-    ASSERT_EQ(cond1->false_stmt(), nullptr);
-    ASSERT_EQ(cond2->false_stmt(), nullptr);
-  }
-  {
-    // Can't fuse, conditions are not identical in rhs (10 != 11).
-    auto body = Block::make(
-        {Cond::make(
-             CompareSelect::make(i, 10, CompareSelectOperation::kLT),
-             Store::make(a, {0}, i),
-             nullptr),
-         Cond::make(
-             CompareSelect::make(i, 11, CompareSelectOperation::kLT),
-             Store::make(a, {1}, i),
-             nullptr)});
-
-    StmtPtr simplified = IRSimplifier::simplify(body);
-    IS_NODE_WITH_NAME(Block, simplified, block);
-    ASSERT_EQ(block->nstmts(), 2);
-    IS_NODE_WITH_NAME(Cond, block->front(), cond1);
-    IS_NODE_WITH_NAME(Cond, block->back(), cond2);
-
-    IS_NODE_WITH_NAME(Block, cond1->true_stmt(), true_stmt1);
-    IS_NODE_WITH_NAME(Block, cond2->true_stmt(), true_stmt2);
-    ASSERT_EQ(true_stmt1->nstmts(), 1);
-    ASSERT_EQ(true_stmt2->nstmts(), 1);
-
-    ASSERT_EQ(cond1->false_stmt(), nullptr);
-    ASSERT_EQ(cond2->false_stmt(), nullptr);
-  }
-
-  {
-    // Can't fuse, conditions are not identical in operation (LT vs GT).
-    auto body = Block::make(
-        {Cond::make(
-             CompareSelect::make(i, 10, CompareSelectOperation::kLT),
-             Store::make(a, {0}, i),
-             nullptr),
-         Cond::make(
-             CompareSelect::make(i, 10, CompareSelectOperation::kGT),
-             Store::make(a, {1}, i),
-             nullptr)});
-
-    StmtPtr simplified = IRSimplifier::simplify(body);
-    IS_NODE_WITH_NAME(Block, simplified, block);
-    ASSERT_EQ(block->nstmts(), 2);
-    IS_NODE_WITH_NAME(Cond, block->front(), cond1);
-    IS_NODE_WITH_NAME(Cond, block->back(), cond2);
-
-    IS_NODE_WITH_NAME(Block, cond1->true_stmt(), true_stmt1);
-    IS_NODE_WITH_NAME(Block, cond2->true_stmt(), true_stmt2);
-    ASSERT_EQ(true_stmt1->nstmts(), 1);
-    ASSERT_EQ(true_stmt2->nstmts(), 1);
-
-    ASSERT_EQ(cond1->false_stmt(), nullptr);
-    ASSERT_EQ(cond2->false_stmt(), nullptr);
-  }
-
-  {
-    // Can't fuse, CompareSelect results are different.
-    // Actually we totally could if we normalized CompareSelect results, but
-    // TODO for later.
-    auto body = Block::make(
-        {Cond::make(
-             CompareSelect::make(i, 10, 1, 0, CompareSelectOperation::kLT),
-             Store::make(a, {0}, i),
-             nullptr),
-         Cond::make(
-             CompareSelect::make(j, 10, 2, 0, CompareSelectOperation::kLT),
-             Store::make(a, {1}, i),
-             nullptr)});
-
-    StmtPtr simplified = IRSimplifier::simplify(body);
-    IS_NODE_WITH_NAME(Block, simplified, block);
-    ASSERT_EQ(block->nstmts(), 2);
-    IS_NODE_WITH_NAME(Cond, block->front(), cond1);
-    IS_NODE_WITH_NAME(Cond, block->back(), cond2);
-
-    IS_NODE_WITH_NAME(Block, cond1->true_stmt(), true_stmt1);
-    IS_NODE_WITH_NAME(Block, cond2->true_stmt(), true_stmt2);
-    ASSERT_EQ(true_stmt1->nstmts(), 1);
-    ASSERT_EQ(true_stmt2->nstmts(), 1);
-
-    ASSERT_EQ(cond1->false_stmt(), nullptr);
-    ASSERT_EQ(cond2->false_stmt(), nullptr);
-  }
-
-  {
-    // Can fuse with false stmt only.
-    auto body = Block::make(
-        {Cond::make(
-             CompareSelect::make(i, 10, CompareSelectOperation::kLT),
-             nullptr,
-             Store::make(a, {0}, i)),
-         Cond::make(
-             CompareSelect::make(i, 10, CompareSelectOperation::kLT),
-             nullptr,
-             Store::make(a, {1}, i))});
-
-    StmtPtr simplified = IRSimplifier::simplify(body);
-    IS_NODE_WITH_NAME(Block, simplified, block);
-    ASSERT_EQ(block->nstmts(), 1);
-    IS_NODE_WITH_NAME(Cond, block->front(), cond);
-    IS_NODE_WITH_NAME(Block, cond->false_stmt(), false_stmt);
-    ASSERT_EQ(false_stmt->nstmts(), 2);
-    ASSERT_EQ(cond->true_stmt(), nullptr);
-  }
-
-  {
-    // Can fuse with both true and false stmt.
-    auto body = Block::make(
-        {Cond::make(
-             CompareSelect::make(i, 10, CompareSelectOperation::kLT),
-             Store::make(a, {0}, i),
-             Store::make(b, {0}, i)),
-         Cond::make(
-             CompareSelect::make(i, 10, CompareSelectOperation::kLT),
-             Store::make(a, {1}, i),
-             Store::make(b, {1}, i))});
-
-    StmtPtr simplified = IRSimplifier::simplify(body);
-    IS_NODE_WITH_NAME(Block, simplified, block);
-    ASSERT_EQ(block->nstmts(), 1);
-    IS_NODE_WITH_NAME(Cond, block->front(), cond);
-    IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt);
-    ASSERT_EQ(true_stmt->nstmts(), 2);
-    IS_NODE_WITH_NAME(Block, cond->true_stmt(), false_stmt);
-    ASSERT_EQ(false_stmt->nstmts(), 2);
-  }
-
-  {
-    // Can fuse with mismatched true / false stmt existing.
-    auto body = Block::make(
-        {Cond::make(
-             CompareSelect::make(i, 10, CompareSelectOperation::kLT),
-             Store::make(a, {0}, i),
-             nullptr),
-         Cond::make(
-             CompareSelect::make(i, 10, CompareSelectOperation::kLT),
-             nullptr,
-             Store::make(b, {1}, i))});
-
-    StmtPtr simplified = IRSimplifier::simplify(body);
-    IS_NODE_WITH_NAME(Block, simplified, block);
-    ASSERT_EQ(block->nstmts(), 1);
-    IS_NODE_WITH_NAME(Cond, block->front(), cond);
-    IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt);
-    ASSERT_EQ(true_stmt->nstmts(), 1);
-    IS_NODE_WITH_NAME(Block, cond->true_stmt(), false_stmt);
-    ASSERT_EQ(false_stmt->nstmts(), 1);
-  }
-
-  {
-    // Can fuse partial block contents, i.e. when there are non fused stmts
-    // before and after.
-    // before:
-    // if (j < 10) { A[0] = j; }
-    // if (i < 10) { A[0] = i; }
-    // if (i < 10) { A[1] = i; }
-    // if (i < 11) { A[1] = j; }
-    //
-    // after:
-    //
-    // if (j < 10) { A[0] = j; }
-    // if (i < 10) {
-    //   A[0] = i;
-    //   A[1] = i;
-    // }
-    // if (i < 11) { A[1] = j; }
-
-    auto body = Block::make({
-        Cond::make(
-            CompareSelect::make(j, 10, CompareSelectOperation::kLT),
-            Store::make(a, {0}, j),
-            nullptr),
-        Cond::make(
-            CompareSelect::make(i, 10, CompareSelectOperation::kLT),
-            Store::make(a, {0}, i),
-            nullptr),
-        Cond::make(
-            CompareSelect::make(i, 10, CompareSelectOperation::kLT),
-            Store::make(a, {1}, i),
-            nullptr),
-        Cond::make(
-            CompareSelect::make(i, 11, CompareSelectOperation::kLT),
-            Store::make(a, {1}, j),
-            nullptr),
-    });
-    StmtPtr simplified = IRSimplifier::simplify(body);
-    IS_NODE_WITH_NAME(Block, simplified, block);
-    ASSERT_EQ(block->nstmts(), 3);
-    auto it = block->begin();
-    it++;
-    IS_NODE_WITH_NAME(Cond, *it, cond);
-    IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt);
-    ASSERT_EQ(true_stmt->nstmts(), 2);
-    ASSERT_EQ(cond->false_stmt(), nullptr);
-  }
-
-  {
-    // Can fuse longer sequences of identical conditions.
-    auto body = Block::make({
-        Cond::make(
-            CompareSelect::make(i, 10, CompareSelectOperation::kLT),
-            Store::make(a, {0}, j),
-            nullptr),
-        Cond::make(
-            CompareSelect::make(i, 10, CompareSelectOperation::kLT),
-            Store::make(a, {0}, i),
-            nullptr),
-        Cond::make(
-            CompareSelect::make(i, 10, CompareSelectOperation::kLT),
-            Store::make(a, {1}, i),
-            nullptr),
-        Cond::make(
-            CompareSelect::make(i, 10, CompareSelectOperation::kLT),
-            Store::make(a, {1}, j),
-            nullptr),
-    });
-    StmtPtr simplified = IRSimplifier::simplify(body);
-    IS_NODE_WITH_NAME(Block, simplified, block);
-    ASSERT_EQ(block->nstmts(), 1);
-    IS_NODE_WITH_NAME(Cond, block->front(), cond);
-    IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt);
-    ASSERT_EQ(true_stmt->nstmts(), 4);
-    ASSERT_EQ(cond->false_stmt(), nullptr);
-  }
-
-  {
-    // Can't fuse through a non condition.
-    auto body = Block::make({
-        Cond::make(
-            CompareSelect::make(i, 10, CompareSelectOperation::kLT),
-            Store::make(a, {0}, j),
-            nullptr),
-        Cond::make(
-            CompareSelect::make(i, 10, CompareSelectOperation::kLT),
-            Store::make(a, {0}, i),
-            nullptr),
-        Store::make(b, {1}, i + j),
-        Cond::make(
-            CompareSelect::make(i, 10, CompareSelectOperation::kLT),
-            Store::make(a, {1}, i),
-            nullptr),
-        Cond::make(
-            CompareSelect::make(i, 10, CompareSelectOperation::kLT),
-            Store::make(a, {1}, j),
-            nullptr),
-    });
-    StmtPtr simplified = IRSimplifier::simplify(body);
-    IS_NODE_WITH_NAME(Block, simplified, block);
-    ASSERT_EQ(block->nstmts(), 3);
-    IS_NODE_WITH_NAME(Cond, block->front(), cond);
-    IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt);
-    ASSERT_EQ(true_stmt->nstmts(), 2);
-    ASSERT_EQ(cond->false_stmt(), nullptr);
-
-    IS_NODE_WITH_NAME(Cond, block->back(), cond2);
-    IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt2);
-    ASSERT_EQ(true_stmt2->nstmts(), 2);
-    ASSERT_EQ(cond2->false_stmt(), nullptr);
-
-    auto it = block->begin();
-    it++;
-    IS_NODE_WITH_NAME(Store, *it, middle);
-  }
-
-  {
-    // Can fuse if the conditions simplify to the same thing.
-    auto body = Block::make(
-        {Cond::make(
-             CompareSelect::make(
-                 i * 2,
-                 ExprHandle(87) % ExprHandle(11),
-                 CompareSelectOperation::kLT),
-             Store::make(a, {0}, i),
-             nullptr),
-         Cond::make(
-             CompareSelect::make(
-                 i * 2,
-                 ExprHandle(300) / ExprHandle(30),
-                 CompareSelectOperation::kLT),
-             Store::make(a, {1}, i),
-             nullptr)});
-    StmtPtr simplified = IRSimplifier::simplify(body);
-    IS_NODE_WITH_NAME(Block, simplified, block);
-    ASSERT_EQ(block->nstmts(), 1);
-    IS_NODE_WITH_NAME(Cond, block->front(), cond);
-    IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt);
-    ASSERT_EQ(true_stmt->nstmts(), 2);
-    ASSERT_EQ(cond->false_stmt(), nullptr);
-  }
-
-  {
-    // Can fuse non-CompareSelects.
-    // if (i) { X } if (i) { Y } => if (i) { X; Y }
-    auto body = Block::make(
-        {Cond::make(i, Store::make(a, {0}, i), nullptr),
-         Cond::make(i, Store::make(a, {1}, i), nullptr)});
-
-    StmtPtr simplified = IRSimplifier::simplify(body);
-    IS_NODE_WITH_NAME(Block, simplified, block);
-    ASSERT_EQ(block->nstmts(), 1);
-    IS_NODE_WITH_NAME(Cond, block->front(), cond);
-    IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt);
-    ASSERT_EQ(true_stmt->nstmts(), 2);
-    ASSERT_EQ(cond->false_stmt(), nullptr);
-  }
-
-  {
-    // Sanity check won't fuse different non-CompareSelects.
-    auto body = Block::make(
-        {Cond::make(i, Store::make(a, {0}, i), nullptr),
-         Cond::make(j, Store::make(a, {1}, i), nullptr)});
-
-    StmtPtr simplified = IRSimplifier::simplify(body);
-    IS_NODE_WITH_NAME(Block, simplified, block);
-    ASSERT_EQ(block->nstmts(), 2);
-    IS_NODE_WITH_NAME(Cond, block->front(), cond1);
-    IS_NODE_WITH_NAME(Cond, block->back(), cond2);
-  }
-
-  {
-    // Sanity check constant condition elimination still occurs when merging is
-    // possible.
-    auto body = Block::make(
-        {Cond::make(1, Store::make(a, {0}, i), nullptr),
-         Cond::make(1, Store::make(a, {1}, i), nullptr)});
-    StmtPtr simplified = IRSimplifier::simplify(body);
-    IS_NODE_WITH_NAME(Block, simplified, block);
-    ASSERT_EQ(block->nstmts(), 2);
-    IS_NODE_WITH_NAME(Store, block->front(), store1);
-    IS_NODE_WITH_NAME(Store, block->back(), store2);
-  }
-
-  {
-    // Sanity check for-cond reordering occurs after fusing.
-    auto body = For::make(
-        i,
-        0,
-        4,
-        Block::make(
-            {Cond::make(
-                 CompareSelect::make(j, 10, CompareSelectOperation::kLT),
-                 Store::make(a, {1}, Load::make(b, {0})),
-                 nullptr),
-             Cond::make(
-                 CompareSelect::make(j, 10, CompareSelectOperation::kLT),
-                 Store::make(a, {2}, Load::make(b, {0})),
-                 nullptr)}));
-
-    StmtPtr simplified = IRSimplifier::simplify(body);
-    IS_NODE_WITH_NAME(Cond, simplified, cond);
-    IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_block);
-    IS_NODE_WITH_NAME(For, true_block->front(), loop);
-  }
-}
-
-TEST(Simplify, SimplifySyncThreads) {
-  BufHandle a("A", {4}, kInt);
-  VarHandle i("i", kInt);
-
-  {
-    // Merge two inner SyncThreads.
-    auto body = Block::make(
-        // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
-        {Store::make(a, {0}, 1),
-         alloc<SyncThreads>(),
-         alloc<SyncThreads>(),
-         Store::make(a, {1}, 0)});
-    StmtPtr simplified = IRSimplifier::simplify(body);
-    IS_NODE_WITH_NAME(Block, simplified, block);
-    ASSERT_EQ(block->nstmts(), 3);
-    auto it = block->begin();
-    IS_NODE(Store, *it++);
-    IS_NODE(SyncThreads, *it++);
-    IS_NODE(Store, *it++);
-  }
-
-  {
-    // Eliminate outer SyncThreads.
-    auto body = Block::make(
-        {alloc<SyncThreads>(), Store::make(a, {1}, 0), alloc<SyncThreads>()});
-
-    StmtPtr simplified = IRSimplifier::simplify(body);
-    IS_NODE_WITH_NAME(Block, simplified, block);
-    ASSERT_EQ(block->nstmts(), 1);
-    auto it = block->begin();
-    IS_NODE(Store, *it);
-  }
-
-  {
-    // Merge many inner SyncThreads.
-    auto body = Block::make(
-        {Store::make(a, {0}, 1),
-         alloc<SyncThreads>(),
-         alloc<SyncThreads>(),
-         alloc<SyncThreads>(),
-         alloc<SyncThreads>(),
-         alloc<SyncThreads>(),
-         Store::make(a, {1}, 0)});
-
-    StmtPtr simplified = IRSimplifier::simplify(body);
-    IS_NODE_WITH_NAME(Block, simplified, block);
-    ASSERT_EQ(block->nstmts(), 3);
-    auto it = block->begin();
-    IS_NODE(Store, *it++);
-    IS_NODE(SyncThreads, *it++);
-    IS_NODE(Store, *it++);
-  }
-
-  {
-    // Merge multiple outer SyncThreads.
- - { - // Merge multiple outer SyncThreads. - auto body = Block::make( - {alloc<SyncThreads>(), - alloc<SyncThreads>(), - Store::make(a, {1}, 0), - alloc<SyncThreads>(), - alloc<SyncThreads>(), - alloc<SyncThreads>(), - alloc<SyncThreads>()}); - - StmtPtr simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Block, simplified, block); - ASSERT_EQ(block->nstmts(), 1); - auto it = block->begin(); - IS_NODE(Store, *it); - } - - { - // Merge multiple sections. - auto body = Block::make( - {Store::make(a, {0}, 1), - alloc<SyncThreads>(), - alloc<SyncThreads>(), - Store::make(a, {1}, 0), - Store::make(a, {2}, 0), - alloc<SyncThreads>(), - alloc<SyncThreads>(), - alloc<SyncThreads>(), - Store::make(a, {3}, 0)}); - - StmtPtr simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Block, simplified, block); - ASSERT_EQ(block->nstmts(), 6); - auto it = block->begin(); - IS_NODE(Store, *it++); - IS_NODE(SyncThreads, *it++); - IS_NODE(Store, *it++); - IS_NODE(Store, *it++); - IS_NODE(SyncThreads, *it++); - IS_NODE(Store, *it++); - } -} - -TEST(Simplify, SimplifyRampSubBroadcast) { - int num_lanes = 4; - ExprHandle ramp = Ramp::make(ExprHandle(0), ExprHandle(6), num_lanes); - ExprHandle broadcast = Broadcast::make(ExprHandle(-5), num_lanes); - ExprHandle simplified = IRSimplifier::simplify(ramp - broadcast); - RampPtr newRamp = simplified.AsNode<Ramp>(); - IS_NODE_WITH_NAME(IntImm, newRamp->base(), base); - ASSERT_EQ(base->value(), 5); - IS_NODE_WITH_NAME(IntImm, newRamp->stride(), stride); - ASSERT_EQ(stride->value(), 6); - ASSERT_EQ(newRamp->lanes(), num_lanes); -}
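// Editor's note (worked example, not in the original file): with 4 lanes,
// Ramp(0, 6, 4) is {0, 6, 12, 18} and Broadcast(-5, 4) is {-5, -5, -5, -5},
// so the difference is {5, 11, 17, 23}, i.e. Ramp(5, 6, 4); those are exactly
// the base, stride, and lane count the assertions above pick apart.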
- -TEST(Simplify, SimplifyBroadcastTermExpander) { - int num_lanes = 8; - ExprHandle bc0 = Broadcast::make(ExprHandle(0), num_lanes); - ExprHandle bc1 = Broadcast::make(ExprHandle(1), num_lanes); - ExprHandle bc2 = Broadcast::make(ExprHandle(2), num_lanes); - // NB: We need a term in the middle which isn't simplified to trigger the - // relevant path in TermExpander::mutate. The two bc1 terms are brought - // together and simplified to 2 * bc1, which then needs to make 2 multi-lane. - ExprHandle simplified = IRSimplifier::simplify(bc1 + (bc0 / bc2) + bc1); - BufHandle buf("buf", {num_lanes}, kInt); - // The result isn't fully simplified currently and thus would be brittle to - // match. Observe its value instead. - auto store = Store::make(buf, {Ramp::make(0, 1, num_lanes)}, simplified); - SimpleIREvaluator eval(store, {buf}); - std::vector<int> output(num_lanes); - eval(output); - for (const auto i : c10::irange(num_lanes)) { - ASSERT_EQ(output[i], 2); - } -} - -TEST(Simplify, CompareSelectLoopBounds) { - constexpr int N = 8; - BufHandle b("b", {N}, kFloat); - VarHandle n("n", kInt); - VarHandle m("m", kInt); - VarHandle var_N("var_N", kInt); - VarHandle var_M("var_M", kInt); - - auto test_case_fn = [](const VarHandle& n, - const BufHandle& b, - const ExprHandle& start, - const ExprHandle& stop, - const int& cmp_val, - const CompareSelectOperation& cmp_op, - const std::string& check_string) { - StmtPtr s = For::make( - n, - start, - stop, - b.store({n}, CompareSelect::make(n, cmp_val, 0.f, 1.0f, cmp_op))); - s = IRSimplifier::simplify(s); - std::ostringstream oss; - oss << *s; - std::string target_string = "# CHECK: "; - target_string += check_string; - torch::jit::testing::FileCheck().run(target_string, oss.str()); - }; - - auto test_case_nest_loops_fn = [](const VarHandle& n, - const VarHandle& m, - const BufHandle& b, - const ExprHandle& n_start, - const ExprHandle& n_stop, - const ExprHandle& m_start, - const ExprHandle& m_stop, - const CompareSelectOperation& cmp_op, - const std::string& check_string) { - StmtPtr s = For::make( - m, - m_start, - m_stop, - b.store({n, m}, CompareSelect::make(n, m, 0.f, 1.0f, cmp_op))); - StmtPtr root_s = For::make(n, n_start, n_stop, s); - root_s = IRSimplifier::simplify(root_s); - std::ostringstream oss; - oss << *root_s; - std::string target_string = "# CHECK: "; - target_string += check_string; - torch::jit::testing::FileCheck().run(target_string, oss.str()); - }; - - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n < 1 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = 1.f; - // } - test_case_fn(n, b, 1, N, 1, kLT, "b[n] = 1.f;"); - - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n <= 1 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n <= 1 ? 0.f : 1.f; - // } - test_case_fn(n, b, 1, N, 1, kLE, "b[n] = n<=1 ? 0.f : 1.f;"); - - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n <= 0 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = 1.f; - // } - test_case_fn(n, b, 1, N, 0, kLE, "b[n] = 1.f;"); - - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n < 0 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = 1.f; - // } - test_case_fn(n, b, 1, N, 0, kLT, "b[n] = 1.f;"); - - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n < 8 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = 0.f; - // } - test_case_fn(n, b, 1, N, N, kLT, "b[n] = 0.f;"); - - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n <= 7 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = 0.f; - // } - test_case_fn(n, b, 1, N, N - 1, kLE, "b[n] = 0.f;"); - - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n <= 8 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = 0.f; - // } - test_case_fn(n, b, 1, N, N, kLE, "b[n] = 0.f;"); - - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n < 7 ? 
0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n < 7 ? 0.f : 1.f; - // } - test_case_fn(n, b, 1, N, N - 1, kLT, "b[n] = n<7 ? 0.f : 1.f;"); - - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n > 0 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = 0.f; - // } - test_case_fn(n, b, 1, N, 0, kGT, "b[n] = 0.f;"); - - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n > 1 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n > 1 ? 0.f : 1.f; - // } - test_case_fn(n, b, 1, N, 1, kGT, "b[n] = n>1 ? 0.f : 1.f;"); - - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n >= 1 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = 0.f; - // } - test_case_fn(n, b, 1, N, 1, kGE, "b[n] = 0.f;"); - - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n > 7 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = 1.f; - // } - test_case_fn(n, b, 1, N, N - 1, kGT, "b[n] = 1.f;"); - - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n >= 7 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n >= 7 ? 0.f : 1.f; - // } - test_case_fn(n, b, 1, N, N - 1, kGE, "b[n] = n>=7 ? 0.f : 1.f;"); - - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n > 5 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n > 5 ? 0.f : 1.f; - // } - test_case_fn(n, b, 1, N, 5, kGT, "b[n] = n>5 ? 0.f : 1.f;"); - - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n >= 5 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n >= 5 ? 0.f : 1.f; - // } - test_case_fn(n, b, 1, N, 5, kGE, "b[n] = n>=5 ? 0.f : 1.f;"); - - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n > 8 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = 1.f; - // } - test_case_fn(n, b, 1, N, N, kGT, "b[n] = 1.f;"); - - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n >= 8 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = 1.f; - // } - test_case_fn(n, b, 1, N, N, kGE, "b[n] = 1.f;"); - - // Before: - // for (const auto n : c10::irange(1, 2)) { - // b[n] = n == 1 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, 2)) { - // b[1] = 0.f; - // } - test_case_fn(n, b, 1, 2, 1, kEQ, "b[1] = 0.f;"); - - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n == 1 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n == 1 ? 0.f : 1.f; - // } - test_case_fn(n, b, 1, N, 1, kEQ, "b[n] = n==1 ? 0.f : 1.f;"); - - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n == 0 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = 1.f; - // } - test_case_fn(n, b, 1, N, 0, kEQ, "b[n] = 1.f;"); - - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n == 7 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n == 7 ? 0.f : 1.f; - // } - test_case_fn(n, b, 1, N, N - 1, kEQ, "b[n] = n==7 ? 0.f : 1.f;"); - - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n == 8 ? 
0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = 1.f; - // } - test_case_fn(n, b, 1, N, N, kEQ, "b[n] = 1.f;"); - - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n != 1 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n != 1 ? 0.f : 1.f; - // } - test_case_fn(n, b, 1, N, 1, kNE, "b[n] = n!=1 ? 0.f : 1.f;"); - - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n != 7 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n != 7 ? 0.f : 1.f; - // } - test_case_fn(n, b, 1, N, N - 1, kNE, "b[n] = n!=7 ? 0.f : 1.f;"); - - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n != 5 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n != 5 ? 0.f : 1.f; - // } - test_case_fn(n, b, 1, N, 5, kNE, "b[n] = n!=5 ? 0.f : 1.f;"); - - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n != 0 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = 0.f; - // } - test_case_fn(n, b, 1, N, 0, kNE, "b[n] = 0.f;"); - - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n != 8 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = 0.f; - // } - test_case_fn(n, b, 1, N, N, kNE, "b[n] = 0.f;"); - - // Before: - // for (const auto n : c10::irange(10, 20)) { - // for(const auto m : c10::irange(30, 40)) { - // b[n, m] = (n != m) ? 0.f : 1.f; - // } - // } - // After: - // for (const auto n : c10::irange(10, 20)) { - // for(const auto m : c10::irange(30, 40)) { - // b[n, m] = 0.f; - // } - // } - test_case_nest_loops_fn(n, m, b, 10, 20, 30, 40, kNE, "b[n, m] = 0.f;"); - test_case_nest_loops_fn( - n, - m, - b, - var_N + 10, - var_N + 20, - var_N + 30, - var_N + 40, - kNE, - "b[n, m] = 0.f;"); - test_case_nest_loops_fn( - n, - m, - b, - var_N + 10, - var_N + 20, - var_M + 30, - var_M + 40, - kNE, - "b[n, m] = n!=m ? 0.f : 1.f;"); - - // Before: - // for (const auto n : c10::irange(30, 40)) { - // for(const auto m : c10::irange(10, 20)) { - // b[n, m] = (n != m) ? 0.f : 1.f; - // } - // } - // After: - // for (const auto n : c10::irange(30, 40)) { - // for(const auto m : c10::irange(10, 20)) { - // b[n, m] = 0.f; - // } - // } - test_case_nest_loops_fn(n, m, b, 30, 40, 10, 20, kNE, "b[n, m] = 0.f;"); - test_case_nest_loops_fn( - n, - m, - b, - var_N + 30, - var_N + 40, - var_N + 10, - var_N + 20, - kNE, - "b[n, m] = 0.f;"); - test_case_nest_loops_fn( - n, - m, - b, - var_N + 30, - var_N + 40, - var_M + 10, - var_M + 20, - kNE, - "b[n, m] = n!=m ? 0.f : 1.f;"); - - // Before: - // for (const auto n : c10::irange(30, 40)) { - // for(const auto m : c10::irange(10, 31)) { - // b[n, m] = (n != m) ? 0.f : 1.f; - // } - // } - // After: - // for (const auto n : c10::irange(30, 40)) { - // for(const auto m : c10::irange(10, 31)) { - // b[n, m] = (n != m) ? 0.f : 1.f; - // } - // } - test_case_nest_loops_fn( - n, m, b, 30, 40, 10, 31, kNE, "b[n, m] = n!=m ? 0.f : 1.f;"); - test_case_nest_loops_fn( - n, - m, - b, - var_N + 30, - var_N + 40, - var_N + 10, - var_N + 31, - kNE, - "b[n, m] = n!=m ? 0.f : 1.f;"); - test_case_nest_loops_fn( - n, - m, - b, - var_N + 30, - var_N + 40, - var_M + 10, - var_M + 31, - kNE, - "b[n, m] = n!=m ? 0.f : 1.f;"); - - // Before: - // for (const auto n : c10::irange(10, 31)) { - // for(const auto m : c10::irange(30, 40)) { - // b[n, m] = (n != m) ? 
0.f : 1.f; - // } - // } - // After: - // for (const auto n : c10::irange(10, 31)) { - // for(const auto m : c10::irange(30, 40)) { - // b[n, m] = (n != m) ? 0.f : 1.f; - // } - // } - test_case_nest_loops_fn( - n, m, b, 10, 31, 30, 40, kNE, "b[n, m] = n!=m ? 0.f : 1.f;"); - test_case_nest_loops_fn( - n, - m, - b, - var_N + 10, - var_N + 31, - var_N + 30, - var_N + 40, - kNE, - "b[n, m] = n!=m ? 0.f : 1.f;"); - test_case_nest_loops_fn( - n, - m, - b, - var_N + 10, - var_N + 31, - var_M + 30, - var_M + 40, - kNE, - "b[n, m] = n!=m ? 0.f : 1.f;"); - - // Before: - // for (const auto n : c10::irange(10, 20)) { - // for(const auto m : c10::irange(30, 40)) { - // b[n, m] = (n < m) ? 0.f : 1.f; - // } - // } - // After: - // for (const auto n : c10::irange(10, 20)) { - // for(const auto m : c10::irange(30, 40)) { - // b[n, m] = 0.f; - // } - // } - test_case_nest_loops_fn(n, m, b, 10, 20, 30, 40, kLT, "b[n, m] = 0.f;"); - test_case_nest_loops_fn( - n, - m, - b, - var_N + 10, - var_N + 20, - var_N + 30, - var_N + 40, - kLT, - "b[n, m] = 0.f;"); - test_case_nest_loops_fn( - n, - m, - b, - var_N + 10, - var_N + 20, - var_M + 30, - var_M + 40, - kLT, - "b[n, m] = n<m ? 0.f : 1.f;"); - - // Before: - // for (const auto n : c10::irange(30, 40)) { - // for(const auto m : c10::irange(10, 20)) { - // b[n, m] = (n > m) ? 0.f : 1.f; - // } - // } - // After: - // for (const auto n : c10::irange(30, 40)) { - // for(const auto m : c10::irange(10, 20)) { - // b[n, m] = 0.f; - // } - // } - test_case_nest_loops_fn(n, m, b, 30, 40, 10, 20, kGT, "b[n, m] = 0.f;"); - test_case_nest_loops_fn( - n, - m, - b, - var_N + 30, - var_N + 40, - var_N + 10, - var_N + 20, - kGT, - "b[n, m] = 0.f;"); - test_case_nest_loops_fn( - n, - m, - b, - var_N + 30, - var_N + 40, - var_M + 10, - var_M + 20, - kGT, - "b[n, m] = n>m ? 0.f : 1.f;"); - - // Before: - // for (const auto n : c10::irange(10, 31)) { - // for(const auto m : c10::irange(30, 40)) { - // b[n, m] = (n > m) ? 0.f : 1.f; - // } - // } - // After: - // for (const auto n : c10::irange(10, 31)) { - // for(const auto m : c10::irange(30, 40)) { - // b[n, m] = 1.f; - // } - // } - test_case_nest_loops_fn(n, m, b, 10, 31, 30, 40, kGT, "b[n, m] = 1.f;"); - test_case_nest_loops_fn( - n, - m, - b, - var_N + 10, - var_N + 31, - var_N + 30, - var_N + 40, - kGT, - "b[n, m] = 1.f;"); - test_case_nest_loops_fn( - n, - m, - b, - var_N + 10, - var_N + 31, - var_M + 30, - var_M + 40, - kGT, - "b[n, m] = n>m ? 0.f : 1.f;"); - - // Before: - // for (const auto n : c10::irange(30, 40)) { - // for(const auto m : c10::irange(10, 31)) { - // b[n, m] = (n >= m) ? 0.f : 1.f; - // } - // } - // After: - // for (const auto n : c10::irange(30, 40)) { - // for(const auto m : c10::irange(10, 31)) { - // b[n, m] = 0.f; - // } - // } - test_case_nest_loops_fn(n, m, b, 30, 40, 10, 31, kGE, "b[n, m] = 0.f;"); - test_case_nest_loops_fn( - n, - m, - b, - var_N + 30, - var_N + 40, - var_N + 10, - var_N + 31, - kGE, - "b[n, m] = 0.f;"); - test_case_nest_loops_fn( - n, - m, - b, - var_N + 30, - var_N + 40, - var_M + 10, - var_M + 31, - kGE, - "b[n, m] = n>=m ? 0.f : 1.f;");
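// Editor's note (reasoning behind the folds above, not in the original
// file): the simplifier treats the loop ranges as intervals. For n in
// [10, 20) and m in [30, 40), max(n) = 19 < 30 = min(m), so n < m holds on
// every iteration and the CompareSelect folds to 0.f. With the symbolic
// var_M bounds the two intervals may overlap, so the select must be kept.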
- - // Before: - // for (const auto n : c10::irange(10, 20)) { - // for(const auto m : c10::irange(30, 40)) { - // b[n, m] = (n >= m) ? 0.f : 1.f; - // } - // } - // After: - // for (const auto n : c10::irange(10, 20)) { - // for(const auto m : c10::irange(30, 40)) { - // b[n, m] = 1.f; - // } - // } - test_case_nest_loops_fn(n, m, b, 10, 20, 30, 40, kGE, "b[n, m] = 1.f;"); - test_case_nest_loops_fn( - n, - m, - b, - var_N + 10, - var_N + 20, - var_N + 30, - var_N + 40, - kGE, - "b[n, m] = 1.f;"); - test_case_nest_loops_fn( - n, - m, - b, - var_N + 10, - var_N + 20, - var_M + 30, - var_M + 40, - kGE, - "b[n, m] = n>=m ? 0.f : 1.f;"); - - // Before: - // for (const auto n : c10::irange(10, 31)) { - // for(const auto m : c10::irange(30, 40)) { - // b[n, m] = (n <= m) ? 0.f : 1.f; - // } - // } - // After: - // for (const auto n : c10::irange(10, 31)) { - // for(const auto m : c10::irange(30, 40)) { - // b[n, m] = 0.f; - // } - // } - test_case_nest_loops_fn(n, m, b, 10, 31, 30, 40, kLE, "b[n, m] = 0.f;"); - test_case_nest_loops_fn( - n, - m, - b, - var_N + 10, - var_N + 31, - var_N + 30, - var_N + 40, - kLE, - "b[n, m] = 0.f;"); - test_case_nest_loops_fn( - n, - m, - b, - var_N + 10, - var_N + 31, - var_M + 30, - var_M + 40, - kLE, - "b[n, m] = n<=m ? 0.f : 1.f;"); - - // Before: - // for (const auto n : c10::irange(30, 40)) { - // for(const auto m : c10::irange(10, 20)) { - // b[n, m] = (n <= m) ? 0.f : 1.f; - // } - // } - // After: - // for (const auto n : c10::irange(30, 40)) { - // for(const auto m : c10::irange(10, 20)) { - // b[n, m] = 1.f; - // } - // } - test_case_nest_loops_fn(n, m, b, 30, 40, 10, 20, kLE, "b[n, m] = 1.f;"); - test_case_nest_loops_fn( - n, - m, - b, - var_N + 30, - var_N + 40, - var_N + 10, - var_N + 20, - kLE, - "b[n, m] = 1.f;"); - test_case_nest_loops_fn( - n, - m, - b, - var_N + 30, - var_N + 40, - var_M + 10, - var_M + 20, - kLE, - "b[n, m] = n<=m ? 0.f : 1.f;"); -} - -TEST(Simplify, CompareSelectCondAlwaysInLoopBounds) { - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = n < 1 ? 0.f : 1.f; - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = 1.f; - // } - constexpr int N = 8; - BufHandle b("b", {N}, kFloat); - VarHandle n("n", kInt); - StmtPtr s = For::make( - n, 1, N, b.store({n}, CompareSelect::make(n, 1, 0.f, 1.0f, kLT))); - s = IRSimplifier::simplify(s); - std::ostringstream oss; - oss << *s; - torch::jit::testing::FileCheck().run( - R"IR( -# CHECK: b[n] = 1.f; -)IR", - oss.str()); -} - -TEST(Simplify, IfThenCondAlwaysInLoopBounds) { - // Before: - // for (const auto n : c10::irange(1, N)) { - // b[n] = IfThenElse(n < 1 ? 1 : 0, 0.f, 1.f); - // } - // After: - // for (const auto n : c10::irange(1, N)) { - // b[n] = 1.f; - // } - constexpr int N = 8; - BufHandle b("b", {N}, kFloat); - VarHandle n("n", kInt); - StmtPtr s = - For::make(n, 1, N, b.store({n}, IfThenElse::make(n < 1, 0.f, 1.0f))); - s = IRSimplifier::simplify(s); - std::ostringstream oss; - oss << *s; - torch::jit::testing::FileCheck().run( - R"IR( -# CHECK: b[n] = 1.f; -)IR", - oss.str()); -} - -TEST(Simplify, MultiClauseCondAlwaysInLoopBounds) { - // This test mimics the unpadded region of a conv2d. We want to remove any - // conditional that is provably satisfied (or unsatisfied) by the entire loop - // range. - // Before: - // for (const auto i : c10::irange(1, 7)) { - // for (const auto j : c10::irange(1, 7)) { - // b[i, j] = IfThenElse( - // j>=7 ? 1 : (i>=7 ? 1 : (j<1 ? 1 : (i<1 ? 
1 : 0))), 0.f, 1.f); - // After: - // for (const auto i : c10::irange(1, 7)) { - // for (const auto j : c10::irange(1, 7)) { - // b[i, j] = 1.f; - constexpr int N = 8; - BufHandle b("b", {N, N}, kFloat); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - auto csel = CompareSelect::make(i, 1, kLT); - csel = CompareSelect::make(j, 1, 1, csel, kLT); - csel = CompareSelect::make(i, N - 1, 1, csel, kGE); - csel = CompareSelect::make(j, N - 1, 1, csel, kGE); - StmtPtr s = b.store({i, j}, IfThenElse::make(csel, 0.f, 1.0f)); - s = For::make(j, 1, N - 1, s); - s = For::make(i, 1, N - 1, s); - s = IRSimplifier::simplify(s); - std::ostringstream oss; - oss << *s; - torch::jit::testing::FileCheck().run( - R"IR( -# CHECK: b[i, j] = 1.f; -)IR", - oss.str()); -} - -TEST(Simplify, DISABLED_SimplifyLoopBounds) { - // This test mimics the padded region of a conv2d. We want to adjust the - // loop bounds such that the condition will be always met. Note that this - // could be solved by peeling, and applying the range-based conditional - // simplification in the previous tests. - // Before: - // for (const auto i : c10::irange(3)) { - // for (const auto j : c10::irange(3)) { - // b[i, j] = (b[i, j]) + (IfThenElse( - // j>=7 ? 1 : (i>=7 ? 1 : (j<1 ? 1 : (i<1 ? 1 : 0))), 0.f, a[i, j])); - // After: - // for (const auto i : c10::irange(1, 3)) { - // for (const auto j : c10::irange(1, 3)) { - // b[i, j] = (b[i, j]) + 1.f; - constexpr int N = 8; - constexpr int K = 3; - BufHandle a("a", {N, N}, kFloat); - BufHandle b("b", {N, N}, kFloat); - VarHandle i("i", kInt); - VarHandle j("j", kInt); - auto csel = CompareSelect::make(i, 1, kLT); - csel = CompareSelect::make(j, 1, 1, csel, kLT); - csel = CompareSelect::make(i, N - 1, 1, csel, kGE); - csel = CompareSelect::make(j, N - 1, 1, csel, kGE); - StmtPtr s = b.store( - {i, j}, b.load({i, j}) + IfThenElse::make(csel, 0.f, a.load({i, j}))); - s = For::make(j, 0, K, s); - s = For::make(i, 0, K, s); - s = IRSimplifier::simplify(s); - std::ostringstream oss; - oss << *s; - torch::jit::testing::FileCheck().run( - R"IR( -# CHECK: for (const auto i : c10::irange(1, 3)) { -# CHECK: for (const auto j : c10::irange(1, 3)) { -# CHECK-NOT: IfThenElse -)IR", - oss.str()); -} - -} // namespace jit -} // namespace torch diff --git a/test/cpp/tensorexpr/test_te_fuser_pass.cpp b/test/cpp/tensorexpr/test_te_fuser_pass.cpp deleted file mode 100644 index 56535de914e4..000000000000 --- a/test/cpp/tensorexpr/test_te_fuser_pass.cpp +++ /dev/null @@ -1,402 +0,0 @@ -#include - -#include -#include -#include -#include -#include -#include -#include -#include - -namespace torch { -namespace jit { - -using namespace torch::jit::tensorexpr; - -struct WithCPUFuser { - WithCPUFuser(bool val = true) : cpuFuserEnabled(canFuseOnCPU()) { - overrideCanFuseOnCPU(val); - } - - ~WithCPUFuser() { - overrideCanFuseOnCPU(cpuFuserEnabled); - } - - bool cpuFuserEnabled; -}; - -TEST(TEFuserPass, FuserPass_1) { - WithCPUFuser cf; - const auto graph_string = R"IR( - graph(%0 : Float(128, strides=[1], device=cpu), - %1 : Float(128, strides=[1], device=cpu)): - %12 : int = prim::Constant[value=1]() - %2.1 : Float(128, strides=[1], device=cpu) = aten::mul(%0, %1) - %2 : Float(128, strides=[1], device=cpu) = aten::mul(%2.1, %1) - %3 : Float(128, strides=[1], device=cpu) = aten::add_(%2, %1, %12) - %4 : Float(128, strides=[1], device=cpu) = aten::mul(%2, %1) - %5 : Float(128, strides=[1], device=cpu) = aten::add(%2, %4, %12) - return (%5))IR"; - auto g = std::make_shared(); - torch::jit::parseIR(graph_string, g.get()); 
- - g->lint(); - FuseTensorExprs(g); - - // We should not be able to fuse across the in-place operation here. - testing::FileCheck() - .check("prim::TensorExprGroup_") - ->check("aten::add_") - ->check("prim::TensorExprGroup_") - ->run(*g); -} - -TEST(TEFuserPass, FuserPass_2) { - WithCPUFuser cf; - const auto graph_string = R"IR( - graph(%0 : Float(128, strides=[1], device=cpu), - %1 : Float(128, strides=[1], device=cpu)): - %12 : int = prim::Constant[value=1]() - %a : Float(128, strides=[1], device=cpu) = aten::mul(%0, %1) - %b : Float(128, strides=[1], device=cpu) = aten::add(%0, %1, %12) - %c : Float(128, strides=[1], device=cpu) = aten::add_(%b, %1, %12) - %d : Float(128, strides=[1], device=cpu) = aten::mul(%c, %a) - return (%d))IR"; - auto g = std::make_shared(); - torch::jit::parseIR(graph_string, g.get()); - - g->lint(); - FuseTensorExprs(g); - - // We should not be able to fuse across the in-place operation here. - testing::FileCheck() - .check("aten::add_") - ->check("prim::TensorExprGroup_0") - ->run(*g); -} - -TEST(TEFuserPass, FuserPass_3) { - WithCPUFuser cf; - const auto graph_string = R"IR( - graph(%x : Float(128, strides=[1], device=cpu), - %y : Float(128, strides=[1], device=cpu)): - %r : Float(128, strides=[1], device=cpu) = aten::mul(%x, %y) - return (%r))IR"; - { - auto g = std::make_shared(); - torch::jit::parseIR(graph_string, g.get()); - - g->lint(); - FuseTensorExprs(g, /* min_group_size= */ 2); - - // We should not create a fusion group since its size would be too small - testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g); - } - { - auto g = std::make_shared(); - torch::jit::parseIR(graph_string, g.get()); - - g->lint(); - FuseTensorExprs(g, /* min_group_size= */ 1); - - // We should create a fusion group since its size is above the threshold - testing::FileCheck().check("prim::TensorExprGroup")->run(*g); - } -} - -TEST(TEFuserPass, FuserPass_0DimInput) { - WithCPUFuser cf; - const auto graph_string = R"IR( - graph(%x : Float(device=cpu), - %y : Float(device=cpu)): - %one : int = prim::Constant[value=1]() - %a : Float(device=cpu) = aten::mul(%x, %y) - %b : Float(device=cpu) = aten::add(%x, %a, %one) - return (%b))IR"; - auto g = std::make_shared(); - torch::jit::parseIR(graph_string, g.get()); - - g->lint(); - FuseTensorExprs(g); - - // We should fuse 0-dim tensors too - testing::FileCheck().check("prim::TensorExprGroup")->run(*g); -} - -TEST(TEFuserPass, FuserPass_UnfusibleDevice) { - WithCPUFuser cf(false); - const auto graph_string = R"IR( - graph(%x : Float(10, strides=[1], device=cpu), - %y : Float(10, strides=[1], device=cpu)): - %a : Float(10, strides=[1], device=cpu) = aten::mul(%x, %y) - return (%a))IR"; - auto g = std::make_shared(); - torch::jit::parseIR(graph_string, g.get()); - - g->lint(); - FuseTensorExprs(g, /* min_group_size= */ 1); - - // Test that we're not starting fusion groups from nodes with unfusible device - testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g); -} - -TEST(TEFuserPass, FuserPass_UnknownShapes) { - WithCPUFuser cf; - const auto graph_string = R"IR( - graph(%x : Tensor, - %y : Tensor): - %a : Tensor = aten::mul(%x, %y) - %b : Tensor = aten::mul(%x, %a) - return (%b))IR"; - auto g = std::make_shared(); - torch::jit::parseIR(graph_string, g.get()); - - g->lint(); - FuseTensorExprs(g); - - // Test that we're not generating fusion groups when shapes are not known - testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g); -} - -TEST(TEFuserPass, FuserPass_Multidevice) { - { - WithCPUFuser cf; - 
const auto graph_string = R"IR( - graph(%x : Float(10, strides=[1], device=cpu), - %y : Float(20, strides=[1], device=cpu), - %z : Float(30, strides=[1], device=cpu)): - %dim : int = prim::Constant[value=0]() - %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z) - %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim) - return (%cat))IR"; - auto g = std::make_shared(); - torch::jit::parseIR(graph_string, g.get()); - - g->lint(); - FuseTensorExprs(g, /* min_group_size= */ 1); - - // We should be able to fuse this - testing::FileCheck().check("prim::TensorExprGroup")->run(*g); - } - { - WithCPUFuser cf; - const auto graph_string = R"IR( - graph(%x : Float(10, strides=[1], device=cpu), - %y : Float(20, strides=[1], device=cuda:0), - %z : Float(30, strides=[1], device=cpu)): - %dim : int = prim::Constant[value=0]() - %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z) - %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim) - return (%cat))IR"; - auto g = std::make_shared(); - torch::jit::parseIR(graph_string, g.get()); - - g->lint(); - FuseTensorExprs(g, /* min_group_size= */ 1); - - // We should not fuse this aten::cat since its inputs are from different - // devices - testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g); - } - { - WithCPUFuser cf; - const auto graph_string = R"IR( - graph(%x : Float(10, strides=[1], device=cpu), - %y : Float(20, strides=[1], device=cpu), - %z : Float(10, strides=[1], device=cuda:0)): - %dim : int = prim::Constant[value=0]() - %xy_list : Tensor[] = prim::ListConstruct(%x, %y) - %xy_cat : Float(30, strides=[1], device=cpu) = aten::cat(%xy_list, %dim) - %r : Float(30, strides=[1], device=cpu) = aten::mul(%xy_cat, %z) - return (%r))IR"; - auto g = std::make_shared(); - torch::jit::parseIR(graph_string, g.get()); - - g->lint(); - FuseTensorExprs(g, /* min_group_size= */ 2); - - // Test that we check device before merging one node (cat) into another - // (mul) - testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g); - } - { - WithCPUFuser cf; - const auto graph_string = R"IR( - graph(%x : Float(10, strides=[1], device=cpu), - %y : Float(20, strides=[1], device=cpu), - %z : Float(10, strides=[1], device=cuda:0)): - %z2 : Tensor = aten::mul(%z, %z) - %dim : int = prim::Constant[value=0]() - %xy_list : Tensor[] = prim::ListConstruct(%x, %y, %z2) - %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xy_list, %dim) - return (%cat))IR"; - auto g = std::make_shared(); - torch::jit::parseIR(graph_string, g.get()); - - g->lint(); - FuseTensorExprs(g, /* min_group_size= */ 2); - - // Test that we check device before merging one node (mul) into another - // (cat) - testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g); - } - { - WithCPUFuser cf; - const auto graph_string = R"IR( - graph(%x : Float(10, strides=[1], device=cpu), - %y : Float(20, strides=[1], device=cuda:0)): - %r : Float(10, strides=[1], device=cpu) = aten::mul(%x, %y) - return (%r))IR"; - auto g = std::make_shared(); - torch::jit::parseIR(graph_string, g.get()); - - g->lint(); - FuseTensorExprs(g, /* min_group_size= */ 1); - - // We should not fuse this graph since its inputs are from different devices - testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g); - } - { - WithCPUFuser cf; - const auto graph_string = R"IR( - graph(%x : Float(10, strides=[1], device=cuda:0), - %y : Float(20, strides=[1], device=cuda:1), - %z : Float(20, strides=[1], device=cpu)): - %x2 : Float(10, strides=[1], device=cpu) = aten::mul(%x, %x) - 
%y2 : Float(10, strides=[1], device=cpu) = aten::mul(%y, %y) - %z2 : Float(10, strides=[1], device=cpu) = aten::mul(%z, %z) - return (%x2, %y2, %z2))IR"; - auto g = std::make_shared(); - torch::jit::parseIR(graph_string, g.get()); - - g->lint(); - FuseTensorExprs(g, /* min_group_size= */ 2); - - // We should not fuse these two computations since they use different - // devices - testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g); - } -} - -TEST(TEFuserPass, FuserPass_MergeGroups) { - WithCPUFuser cf; - const auto graph_string = R"IR( - graph(%a : Float(128, strides=[1], device=cpu), - %b : Float(128, strides=[1], device=cpu)): - %x : Float(128, strides=[1], device=cpu) = aten::mul(%a, %a) - %y : Float(128, strides=[1], device=cpu) = aten::mul(%b, %b) - return (%x, %y))IR"; - auto g = std::make_shared(); - torch::jit::parseIR(graph_string, g.get()); - - g->lint(); - FuseTensorExprs(g, /* min_group_size= */ 1); - - // The %x and %y computations are completely independent and yet we should put - // them into a single fusion group rather than having two separate ones. - testing::FileCheck() - .check("= prim::TensorExprGroup_") - ->check_not("= prim::TensorExprGroup_") - ->run(*g); -} - -TEST(TEFuserPass, FuserPass_IgnoreUnknownShapeAtStart) { - WithCPUFuser cf; - const auto graph_string = R"IR( - graph(%x : Bool(8, strides=[1], device=cpu), - %y : Bool(8, strides=[1], device=cpu)): - %a : Bool(8, strides=[1], device=cpu) = aten::__and__(%x, %y) - %b : Tensor = aten::__or__(%a, %y) - return (%b) - )IR"; - auto g = std::make_shared(); - torch::jit::parseIR(graph_string, g.get()); - g->lint(); - FuseTensorExprs(g, /* min_group_size= */ 2); - testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g); -} - -TEST(TEFuserPass, FuserPass_Where) { - WithCPUFuser cf; - const auto graph_string = R"IR( - graph(%x : Float(8, strides=[1], device=cpu), - %y : Float(8, strides=[1], device=cpu), - %z : Float(8, strides=[1], device=cpu)): - %cond : Bool(8, strides=[1], device=cpu) = aten::eq(%x, %y) - %b : Float(8, strides=[1], device=cpu) = aten::where(%cond, %y, %z) - return (%b) - )IR"; - auto g = std::make_shared(); - torch::jit::parseIR(graph_string, g.get()); - g->lint(); - FuseTensorExprs(g, /* min_group_size= */ 2); - testing::FileCheck().check("prim::TensorExprGroup")->run(*g); -} - -TEST(TEFuserPass, FuserPass_WhereList) { - WithCPUFuser cf; - const auto graph_string = R"IR( - graph(%x : Float(8, strides=[1], device=cpu), - %y : Float(8, strides=[1], device=cpu), - %z : Float(8, strides=[1], device=cpu)): - %cond : Bool(8, strides=[1], device=cpu) = aten::eq(%x, %y) - %b : Tensor[] = aten::where(%cond) - return (%b) - )IR"; - auto g = std::make_shared(); - torch::jit::parseIR(graph_string, g.get()); - g->lint(); - FuseTensorExprs(g, /* min_group_size= */ 2); - testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g); -} - -TEST(TEFuserPass, DynamicShapeFusion) { - WithCPUFuser cf; - const auto graph_string = R"IR( - graph(%0 : Float(10, 5, strides=[5, 1], device=cpu), - %1 : Float(10, 5, strides=[5, 1], device=cpu)): - %2 : Float(10, 5, strides=[5, 1], device=cpu) = aten::mul(%0, %1) - %3 : Float(10, 5, strides=[5, 1], device=cpu) = aten::mul(%2, %1) - return (%3))IR"; - auto g = std::make_shared(); - torch::jit::parseIR(graph_string, g.get()); - - g->lint(); - FuseTensorExprs( - g, - /* min_group_size = */ 2, - /* add_composed_op = */ true, - /* fuse_to_dynamic_shapes = */ true); - Code code(g, ""); - - testing::FileCheck() - .check("prim::TensorExprDynamicGroup_") - 
->check("prim::TensorExprDynamicGuard") - ->check("prim::TensorExprGroup_") - ->run(*g); - - auto run_and_compare = [&](const std::vector<at::Tensor>& inputs) { - TORCH_INTERNAL_ASSERT(inputs.size() == 2); - - auto ref = at::mul(at::mul(inputs[0], inputs[1]), inputs[1]); - - InterpreterState interp(code); - Stack stack(inputs.begin(), inputs.end()); - interp.run(stack); - at::Tensor out = pop(stack).toTensor(); - ASSERT_TRUE(at::allclose(out, ref)); - }; - - std::vector<at::Tensor> inputs = {at::rand({10, 5}), at::rand({10, 5})}; - run_and_compare(inputs); - - std::vector<at::Tensor> inputs2 = {at::rand({20, 5}), at::rand({20, 5})}; - run_and_compare(inputs2); - - std::vector<at::Tensor> inputs3 = {at::rand({25, 60}), at::rand({25, 60})}; - run_and_compare(inputs3); -} - -} // namespace jit -} // namespace torch diff --git a/test/cpp/tensorexpr/test_type.cpp b/test/cpp/tensorexpr/test_type.cpp deleted file mode 100644 index 6758503f4de7..000000000000 --- a/test/cpp/tensorexpr/test_type.cpp +++ /dev/null @@ -1,202 +0,0 @@ -#include <gtest/gtest.h> - -#include "torch/csrc/jit/tensorexpr/eval.h" -#include "torch/csrc/jit/tensorexpr/ir.h" -#include "torch/csrc/jit/tensorexpr/tensor.h" - -namespace torch { -namespace jit { -using namespace torch::jit::tensorexpr; - -TEST(Type, Test01) { - { - Dtype dt1 = kInt; - ASSERT_EQ(dt1, kInt); - } - { - Dtype dt2_a(kInt, 8); - Dtype dt2_b(kInt, 4); - Dtype dt2_c(ScalarType::Int, 8); - ASSERT_EQ(dt2_a, dt2_c); - ASSERT_NE(dt2_a, dt2_b); - } - { - ASSERT_EQ(kInt, ToDtype<int>()); - ASSERT_EQ(kFloat, ToDtype<float>()); - ASSERT_EQ(kByte, ToDtype<uint8_t>()); - ASSERT_EQ(kChar, ToDtype<int8_t>()); - ASSERT_EQ(kShort, ToDtype<int16_t>()); - ASSERT_EQ(kLong, ToDtype<int64_t>()); - ASSERT_EQ(kHalf, ToDtype<at::Half>()); - ASSERT_EQ(kDouble, ToDtype<double>()); - ASSERT_EQ(kBool, ToDtype<bool>()); - } - { - Dtype int32x8(kInt, 8); - Dtype float32x8(kFloat, 8); - ASSERT_NE(int32x8, float32x8); - ASSERT_EQ(float32x8, BinaryOpDtype(int32x8, float32x8)); - ASSERT_EQ(float32x8, BinaryOpDtype(float32x8, int32x8)); - ASSERT_EQ(int32x8, BinaryOpDtype(int32x8, int32x8)); - ASSERT_EQ(float32x8, BinaryOpDtype(float32x8, float32x8)); - } -} - -TEST(Type, BitCasting) { - { - VarHandle x("x", kFloat); - ExprHandle y = bitcast<int>(x); - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - ASSERT_EQ(y.dtype(), kInt); - } - { - VarHandle x("x", kInt); - ExprHandle y = bitcast<float>(x); - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - ASSERT_EQ(y.dtype(), kFloat); - } - { - VarHandle x("x", kShort); - ExprHandle y = bitcast<at::Half>(x); - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - ASSERT_EQ(y.dtype(), kHalf); - } - { - VarHandle x("x", kHalf); - ExprHandle y = bitcast<int16_t>(x); - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - ASSERT_EQ(y.dtype(), kShort); - } - - constexpr int32_t ref32 = 1337; - constexpr int64_t ref64 = 1337; - constexpr float reff32 = 1337.0f; - constexpr double reff64 = 1337.0f; - using SimpleIRExprEval = ExprEval<SimpleIREvaluator>; - // this is broken - /*{ - constexpr int16_t ref16 = 1337; - at::Half k_; - at::Half* k = &k_; - *reinterpret_cast<int16_t*>(k) = ref16; - auto a = HalfImm::make(*k); - auto b = BitCast::make(kShort, a); - SimpleIRExprEval cg(b); - ASSERT_EQ(cg.value<int16_t>(), ref16); - }*/ - - { - float k = raw_bitcast<float>(ref32); - auto a = FloatImm::make(k); - auto b = BitCast::make(kInt, a); - SimpleIRExprEval cg(b); - ASSERT_EQ(cg.value<int32_t>(), ref32); - } - - { - double k = raw_bitcast<double>(ref64); - auto a = DoubleImm::make(k); - auto b = BitCast::make(kLong, a); - SimpleIRExprEval cg(b); - ASSERT_EQ(cg.value<int64_t>(), ref64); - }
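// Editor's sketch (standalone C++ with a hypothetical helper name; it mirrors
// what the raw_bitcast utility used here is assumed to do): a value-level
// bitcast copies bytes without converting the value, and only makes sense for
// equal-sized types.
#include <cstdint>
#include <cstring>
template <typename To, typename From>
To bit_pattern(From f) {
  static_assert(sizeof(To) == sizeof(From), "bitcast requires equal sizes");
  To t;
  std::memcpy(&t, &f, sizeof(t)); // copy the raw bytes, not the numeric value
  return t;
}
// bit_pattern<int32_t>(1337.0f) yields 0x44A72000, and casting those bits
// back restores 1337.0f exactly; that round trip is what these blocks assert.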
- - { - int64_t k = raw_bitcast<int64_t>(reff64); - auto a = LongImm::make(k); - auto b = BitCast::make(kDouble, a); - SimpleIRExprEval cg(b); - ASSERT_EQ(cg.value<double>(), reff64); - } - - { - int32_t k = raw_bitcast<int32_t>(reff32); - auto a = IntImm::make(k); - auto b = BitCast::make(kFloat, a); - SimpleIRExprEval cg(b); - ASSERT_EQ(cg.value<float>(), reff32); - } - - // This segfaults :( - /*{ - VarHandle x("x", kDouble); - ASSERT_ANY_THROW(ExprHandle y = bitcast(x)); - } - { - VarHandle x("x", kFloat); - ASSERT_ANY_THROW(ExprHandle y = bitcast(x)); - } - { - VarHandle x("x", kLong); - ASSERT_ANY_THROW(ExprHandle y = bitcast(x)); - } - { - VarHandle x("x", kShort); - ASSERT_ANY_THROW(ExprHandle y = bitcast(x)); - } - { - VarHandle x("x", kInt); - ASSERT_ANY_THROW(ExprHandle y = bitcast(x)); - }*/ -} - -TEST(Type, Propagation) { - // Same types: - { - VarHandle x("x", kFloat); - VarHandle y("y", kFloat); - ExprHandle body = FloatImm::make(2.f) + - (x * FloatImm::make(3.f) + FloatImm::make(4.f) * y); - ASSERT_EQ(body.dtype(), kFloat); - } - // Int to bigger int: - { - VarHandle x("x", kShort); - VarHandle y("y", kLong); - ExprHandle body = - ShortImm::make(2.f) + (x * ShortImm::make(3) + ShortImm::make(4) * y); - ASSERT_EQ(body.dtype(), kLong); - } - // Float to bigger float: - { - VarHandle x("x", kHalf); - VarHandle y("y", kDouble); - ExprHandle body = - HalfImm::make(2.f) + (x * HalfImm::make(3) + HalfImm::make(4) * y); - ASSERT_EQ(body.dtype(), kDouble); - } - // Int to Float: - { - VarHandle x("x", kFloat); - VarHandle y("y", kInt); - ExprHandle body = - IntImm::make(2) + (x * IntImm::make(3) + IntImm::make(4) * y); - ASSERT_EQ(body.dtype(), kFloat); - } - // Smaller float, bigger Int: - { - VarHandle x("x", kHalf); - VarHandle y("y", kLong); - ExprHandle body = - HalfImm::make(2) + (x * HalfImm::make(3) + HalfImm::make(4) * y); - ASSERT_EQ(body.dtype(), kHalf); - } - // Bigger float, smaller Int: - { - VarHandle x("x", kChar); - VarHandle y("y", kDouble); - ExprHandle body = - CharImm::make(2) + (x * CharImm::make(3) + CharImm::make(4) * y); - ASSERT_EQ(body.dtype(), kDouble); - } - // Sign change char/byte upgrades to short: - { - VarHandle x("x", kChar); - VarHandle y("y", kByte); - ExprHandle body = - CharImm::make(2) + (x * CharImm::make(3) + CharImm::make(4) * y); - ASSERT_EQ(body.dtype(), kShort); - } -} -} // namespace jit -} // namespace torch diff --git a/test/cpp/tensorexpr/test_type_specializations.cpp b/test/cpp/tensorexpr/test_type_specializations.cpp deleted file mode 100644 index d9756627fa74..000000000000 --- a/test/cpp/tensorexpr/test_type_specializations.cpp +++ /dev/null @@ -1,75 +0,0 @@ -#include - -#include -#include -#include -#include -#include -#include -#include - -// Test that tensor type specializations are available in -// the custom passes - -namespace torch { -namespace jit { - -namespace { - -bool hasTensorTypeSpecializations(torch::jit::Block* block) { - for (Value* v : block->inputs()) { - if (hasTensorTypeSpecialization(v)) - return true; - } - for (Node* n : block->nodes()) { - for (torch::jit::Block* b : n->blocks()) { - if (hasTensorTypeSpecializations(b)) - return true; - } - for (Value* v : n->outputs()) { - if (hasTensorTypeSpecialization(v)) - return true; - } - } - return false; -} - -static bool hasSpecializations = false; -void detectTTSpecializationPass(std::shared_ptr<Graph>& graph) { - GRAPH_DUMP("In detectTTSpecialization Custom Post Pass: ", graph); - hasSpecializations = hasTensorTypeSpecializations(graph->block()); -} - -} // namespace
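// Editor's note (illustrative): constructing the RegisterPass object in the
// test below is assumed to append detectTTSpecializationPass to the JIT's
// list of custom post passes, so every graph compiled afterwards, including
// the GraphExecutor::run call in this test, flows through it.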
- -TEST(SpecializationsInCustomPasses, Basic) { - RegisterPass p(detectTTSpecializationPass); - hasSpecializations = false; - std::shared_ptr<Graph> graph = std::make_shared<Graph>(); - parseIR( - R"IR( -graph(%a.1 : Tensor, - %b.1 : Tensor): - %c.1 : Tensor = aten::mul(%a.1, %b.1) # misc/test_specializations.py:5:8 - %d.1 : Tensor = aten::mul(%c.1, %b.1) # misc/test_specializations.py:6:8 - return (%d.1) - )IR", - &*graph); - - IValue ival = IValue(torch::randn({22}, at::kCPU)); - std::vector<IValue> stack = {ival, ival}; - auto run = [&](std::shared_ptr<Graph>& graph, std::vector<IValue> stack) { - GraphExecutor executor(graph, ""); - executor.run(stack); - return stack; - }; - run(graph, stack); - - // Profiling mode will not be run with simple executor - if (!getExecutorMode()) { - EXPECT_TRUE(hasSpecializations); - } -} - -} // namespace jit -} // namespace torch diff --git a/test/cpp/tensorexpr/test_utils.h b/test/cpp/tensorexpr/test_utils.h deleted file mode 100644 index 065e513c1a64..000000000000 --- a/test/cpp/tensorexpr/test_utils.h +++ /dev/null @@ -1,78 +0,0 @@ -#pragma once - -#include -#include - -#include -#include -#include - -namespace torch { -namespace jit { -using namespace torch::jit::tensorexpr; - -#define IS_NODE(T, node) \ - { \ - auto node_ = to<T>(node); \ - ASSERT_NE(nullptr, node_); \ - } - -#define IS_NODE_WITH_NAME(T, node, name) \ - auto name = to<T>(node); \ - ASSERT_NE(nullptr, name); - -#define IS_NODE_WITH_NAME_AND_CAST(T, node, name, Type) \ - NodePtr<T> name = nullptr; \ - { \ - auto node_ = to<Cast>(node); \ - ASSERT_NE(nullptr, node_); \ - ASSERT_EQ(node_->dtype().scalar_type(), ScalarType::Type); \ - name = to<T>(node_->src_value()); \ - } \ - ASSERT_NE(nullptr, name); - -#define IS_IMM_WITH_VAL(T, node, val) \ - { \ - auto node_ = to<T##Imm>(node); \ - ASSERT_NE(nullptr, node_); \ - ASSERT_EQ(node_->value(), val); \ - } - -#define IS_VAR_WITH_NAME(node, name) \ - { \ - auto node_ = to<Var>(node); \ - ASSERT_NE(nullptr, node_); \ - ASSERT_EQ(node_->name_hint(), name); \ - } - -#define IS_BINOP_W_VARS(T, node, name, v1, v2) \ - NodePtr<T> name = nullptr; \ - { \ - name = to<T>(node); \ - ASSERT_NE(nullptr, name); \ - IS_VAR_WITH_NAME(name->lhs(), v1); \ - IS_VAR_WITH_NAME(name->rhs(), v2); \ - } - -#define IS_BINOP_W_CONST(T, node, name, v, c) \ - NodePtr<T> name = nullptr; \ - { \ - name = to<T>(node); \ - ASSERT_NE(nullptr, name); \ - IS_VAR_WITH_NAME(name->lhs(), v); \ - IS_IMM_WITH_VAL(Int, name->rhs(), c); \ - } - -#define IS_RAND(node) \ - { \ - auto node_ = to<Intrinsics>(node); \ - ASSERT_NE(nullptr, node_); \ - ASSERT_EQ(node_->op_type(), kRand); \ - } - -void checkIR(StmtPtr s, const std::string& pattern); -void checkExprIR(ExprPtr e, const std::string& pattern); -void checkExprIR(const ExprHandle& e, const std::string& pattern); - -} // namespace jit -} // namespace torch diff --git a/test/cpp/tensorexpr/tutorial.cpp b/test/cpp/tensorexpr/tutorial.cpp deleted file mode 100644 index 3f4c32af463b..000000000000 --- a/test/cpp/tensorexpr/tutorial.cpp +++ /dev/null @@ -1,542 +0,0 @@ -// *** Tensor Expressions *** -// -// This tutorial covers basics of NNC's tensor expressions, shows basic APIs to -// work with them, and outlines how they are used in the overall TorchScript -// compilation pipeline. This doc is permanently a "work in progress" since NNC -// is under active development and things change fast. -// -// This tutorial's code is compiled in the standard pytorch build, and the -// executable can be found in `build/bin/tutorial_tensorexpr`. -// -// *** What is NNC *** -// -// NNC stands for Neural Net Compiler. 
It is a component of TorchScript JIT -// and it performs on-the-fly code generation for kernels, which are often a -// combination of multiple aten (torch) operators. -// -// When the JIT interpreter executes a torchscript model, it automatically -// extracts subgraphs from the torchscript IR graph for which specialized code -// can be JIT generated. This usually improves performance, since the 'combined' -// kernel created from the subgraph can avoid the memory traffic that -// is unavoidable when the subgraph is interpreted as-is, operator by operator. -// This optimization is often referred to as 'fusion'. Relatedly, the process of -// finding and extracting subgraphs suitable for NNC code generation is done by -// a JIT pass called 'fuser'. -// -// *** What is TE *** -// -// TE stands for Tensor Expressions. TE is a commonly used approach for -// compiling kernels performing tensor (~matrix) computation. The idea behind it -// is that operators are represented as a mathematical formula describing what -// computation they do (as TEs) and then the TE engine can perform mathematical -// simplification and other optimizations using those formulas and eventually -// generate executable code that would produce the same results as the original -// sequence of operators, but more efficiently. -// -// NNC's design and implementation of TE was heavily inspired by the Halide and -// TVM projects. -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -using namespace torch::jit::tensorexpr; - -#ifdef TORCH_ENABLE_LLVM - -// Helper function to print a snippet from a big multi-line string -static void printLinesToFrom(const std::string& input_str, int from, int to); - -#endif - -int main(int argc, char* argv[]) { - std::cout << "*** Structure of tensor expressions and statements ***" - << std::endl; - { - // A tensor expression is a tree of expressions. Each expression has a type, - // and that type defines what sub-expressions the current expression has. - // For instance, an expression of type 'Mul' would have a type 'kMul' and - // two subexpressions: LHS and RHS. Each of these two sub-expressions could - // also be a 'Mul' or some other expression. - // - // Let's construct a simple TE: - ExprPtr lhs = alloc<IntImm>(5); - ExprPtr rhs = alloc<Var>("x", kInt); - ExprPtr mul = alloc<Mul>(lhs, rhs); - std::cout << "Tensor expression: " << *mul << std::endl; - // Prints: Tensor expression: 5 * x - - // Here we created an expression representing a 5*x computation, where x is - // an int variable. - - // Another, probably more convenient, way to construct tensor expressions - // is to use so-called expression handles (as opposed to raw expressions - // like we did in the previous example). Expression handles overload common - // operations and allow us to express the same semantics in a more natural - // way: - ExprHandle l = 5; - ExprHandle r = Var::make("x", kInt); - ExprHandle m = l * r; - std::cout << "Tensor expression: " << *m.node() << std::endl; - // Prints: Tensor expression: 5 * x - - // Converting from handles to raw expressions and back is easy: - ExprHandle handle = Var::make("x", kInt); - ExprPtr raw_expr_from_handle = handle.node(); - ExprPtr raw_expr = alloc<Var>("x", kInt); - ExprHandle handle_from_raw_expr = ExprHandle(raw_expr);
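// Editor's aside (illustrative, not from the original tutorial): the two
// styles build the same IR shape; handle arithmetic is thin sugar over the
// raw alloc<> calls used above.
ExprHandle sugar = l * r;
ExprPtr desugared = alloc<Mul>(l.node(), r.node());
std::cout << *sugar.node() << " and " << *desugared << std::endl;
// Both print as: 5 * x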
- - // We could construct arbitrarily complex expressions using mathematical - // and logical operations, casts between various data types, and a bunch of - // intrinsics. - ExprHandle a = Var::make("a", kInt); - ExprHandle b = Var::make("b", kFloat); - ExprHandle c = Var::make("c", kFloat); - ExprHandle x = ExprHandle(5) * a + b / (sigmoid(c) - 3.0f); - std::cout << "Tensor expression: " << *x.node() << std::endl; - // Prints: Tensor expression: float(5 * a) + b / ((sigmoid(c)) - 3.f) - - // An ultimate purpose of tensor expressions is to optimize tensor - // computations, and in order to represent accesses to tensor data, there - // is a special kind of expression - a load. - // To construct a load we need two pieces: the base and the indices. The - // base of a load is a Buf expression, which could be thought of as a - // placeholder similar to Var, but with dimension info. - // - // Let's construct a simple load: - BufHandle A("A", {64, 32}, kInt); - VarPtr i_var = alloc<Var>("i", kInt), j_var = alloc<Var>("j", kInt); - ExprHandle i(i_var), j(j_var); - ExprHandle load = Load::make(A.dtype(), A, {i, j}); - std::cout << "Tensor expression: " << *load.node() << std::endl; - // Prints: Tensor expression: A[i, j] - - // Tensor expressions are the building blocks of tensor statements, which - // are used to represent the computation of a given operator or a group of - // operators from a fusion group. - // - // There are three main kinds of tensor statements: - // - block - // - store - // - loop - // - // A Store represents a store to a single element of a tensor (or to a - // group of elements if it's a vectorized store). Store statements, - // similarly to Load expressions, have a base and indices, but on top of - // that they also include a value - an expression representing what needs - // to be stored at the given memory location. Let's create a Store stmt: - StmtPtr store_a = Store::make(A, {i, j}, i + j); - std::cout << "Store statement: " << *store_a << std::endl; - // Prints: Store statement: A[i, j] = i + j; - - // An operator fills the entire tensor, not just a single element, and to - // represent this we need to use a For stmt: let's wrap our store stmt with - // two nested loops to represent that variables i and j need to iterate - // over some ranges. - ForPtr loop_j_a = For::make(VarHandle(j_var), 0, 32, store_a); - ForPtr loop_i_a = For::make(VarHandle(i_var), 0, 64, loop_j_a); - - std::cout << "Nested for loops: " << std::endl << *loop_i_a << std::endl; - // Prints: - // Nested for loops: - // for (const auto i : c10::irange(64)) { - // for (const auto j : c10::irange(32)) { - // A[i, j] = i + j; - // } - // } - - // A Block statement is used when we need a sequence of other statements. - // E.g. if a fusion group contains several operators, we initially define - // a separate loopnest for each of them and put them all into a common block: - BufHandle B("B", {64, 32}, kInt); - StmtPtr store_b = Store::make(B, {i, j}, A.load(i, j)); - ForPtr loop_j_b = For::make(VarHandle(j_var), 0, 32, store_b); - ForPtr loop_i_b = For::make(VarHandle(i_var), 0, 64, loop_j_b); - - BlockPtr block = Block::make({loop_i_a, loop_i_b}); - std::cout << "Compound Block statement: " << std::endl - << *block << std::endl; - // Prints: - // Compound Block statement: - // { - // for (const auto i : c10::irange(64)) { - // for (const auto j : c10::irange(32)) { - // A[i, j] = i + j; - // } - // } - // for (const auto i : c10::irange(64)) { - // for (const auto j : c10::irange(32)) { - // B[i, j] = A[i, j]; - // } - // } - // } - - // Manually constructing nested loops and blocks to represent a computation - // might be laborious, and instead we can use a 'Compute' API. 
This API - // requires us to specify dimensions and a lambda to compute a single - // element of the resulting tensor and returns a `Tensor` structure. This - // structure is simply a pair of a buffer that was created to represent the - // result of the computation (BufPtr) and a statement representing the - // computation itself (StmtPtr). - Tensor C = - Compute("C", {64, 32}, [&](const VarHandle& i, const VarHandle& j) { - return i * j; - }); - std::cout << "Stmt produced by 'Compute' API: " << std::endl - << *C.stmt() << std::endl; - // Prints: - // Stmt produced by 'Compute' API: - // for (const auto i : c10::irange(64)) { - // for (const auto j : c10::irange(32)) { - // C[i, j] = i * j; - // } - // } - - // To construct statements to represent computations with reductions, we - // can use a 'Reduce' API - it is similar to 'Compute' but takes a couple - // of extra arguments defining how to perform the reduction. Let's define a - // simple 2D sum of C using that: - Tensor D = Reduce( - "D", - {}, - Sum(), - [&](const VarHandle& i, const VarHandle& j) { return C.load(i, j); }, - {64, 32}); - std::cout << "Stmt produced by 'Reduce' API: " << std::endl - << *D.stmt() << std::endl; - } - - std::cout << "*** Loopnests transformations ***" << std::endl; - { - // When a statement for the computation is generated, we might want to - // apply some optimizations to it. These transformations allow us to end up - // with a statement producing the same results, but more efficiently. - // - // Let's look at a couple of transformations that are used in NNC. We will - // begin with constructing a Block statement like we did before. - - Tensor C = - Compute("C", {64, 32}, [&](const VarHandle& i, const VarHandle& j) { - return i * (j + 1); - }); - BufHandle c_buf(C.buf()); - Tensor D = - Compute("D", {64, 32}, [&](const VarHandle& i, const VarHandle& j) { - return c_buf.load(i, j) - i; - }); - StmtPtr block = Block::make({C.stmt(), D.stmt()}); - std::cout << "Stmt produced by 'Compute' API: " << std::endl - << *block << std::endl; - // Prints: - // Stmt produced by 'Compute' API: - // { - // for (const auto i : c10::irange(64)) { - // for (const auto j : c10::irange(32)) { - // C[i, j] = i * (j + 1); - // } - // } - // for (const auto i_1 : c10::irange(64)) { - // for (const auto j_1 : c10::irange(32)) { - // D[i_1, j_1] = (C[i_1, j_1]) - i_1; - // } - // } - // } - - // One transformation we can apply to this computation is inlining: i.e. - // taking the expression that defines values of C and substituting a load - // from C with it. - // To do that, we first need to create a special object called LoopNest - - // all transformations are methods of this class. 
To create a loopnest we - // need to provide a list of output buffers and the root statement: - LoopNest nest(block, {D.buf()}); - - // We can always retrieve the Stmt back from LoopNest: - std::cout << "LoopNest root stmt: " << std::endl - << *nest.root_stmt() << std::endl; - // Prints: - // LoopNest root stmt: - // { - // for (const auto i : c10::irange(64)) { - // for (const auto j : c10::irange(32)) { - // C[i, j] = i * (j + 1); - // } - // } - // for (const auto i_1 : c10::irange(64)) { - // for (const auto j_1 : c10::irange(32)) { - // D[i_1, j_1] = (C[i_1, j_1]) - i_1; - // } - // } - // } - - // Now we can apply the inlining transformation: - nest.computeInline(C.buf()); - std::cout << "Stmt after inlining:" << std::endl - << *nest.root_stmt() << std::endl; - // Prints: - // Stmt after inlining: - // { - // for (const auto i : c10::irange(64)) { - // for (const auto j : c10::irange(32)) { - // D[i, j] = i * (j + 1) - i; - // } - // } - // } - - // We can also apply algebraic simplification to a statement: - StmtPtr simplified = IRSimplifier::simplify(nest.root_stmt()); - std::cout << "Stmt after simplification:" << std::endl - << *simplified << std::endl; - // Prints: - // Stmt after simplification: - // { - // for (const auto i : c10::irange(64)) { - // for (const auto j : c10::irange(32)) { - // D[i, j] = i * j; - // } - // } - // } - - // Many loopnest transformations are stateless and can be applied without - // creating a LoopNest object. In fact, we plan to make all transformations - // stateless. - // splitWithTail is one such transformation: it splits an iteration space - // of a given loop into two with a given factor. - ForPtr outer_loop = to<For>(to<Block>(simplified)->stmts().front()); - LoopNest::splitWithTail(outer_loop, 13); - // Call simplifier once more to fold some arithmetic. - simplified = IRSimplifier::simplify(simplified); - std::cout << "Stmt after splitWithTail:" << std::endl - << *simplified << std::endl; - // Prints: - // Stmt after splitWithTail: - // { - // for (const auto i_outer : c10::irange(4)) { - // for (const auto i_inner : c10::irange(13)) { - // for (const auto j : c10::irange(32)) { - // D[i_inner + 13 * i_outer, j] = i_inner * j + 13 * (i_outer * j); - // } - // } - // } - // for (const auto i_tail : c10::irange(12)) { - // for (const auto j : c10::irange(32)) { - // D[i_tail + 52, j] = i_tail * j + 52 * j; - // } - // } - // } - - // NNC supports a wide range of loop nest transformations, which we are not - // listing here. Please refer to documentation in - // https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/tensorexpr/loopnest.h - // for more details. - }
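// Editor's note (worked arithmetic, not in the original tutorial): a split
// factor of 13 over the 64-iteration loop leaves floor(64 / 13) = 4 full
// chunks in the main loop and 64 - 4 * 13 = 12 iterations in the tail, which
// matches the irange(4) and irange(12) pair in the printout above.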
-
-  std::cout << "*** Codegen ***" << std::endl;
-  {
-    // The ultimate goal of tensor expressions is to provide a mechanism for
-    // executing a given computation in the fastest possible way. So far
-    // we've looked at how to describe the computation we're interested in,
-    // but not at how to actually execute it.
-    //
-    // Everything we've been dealing with so far was just symbols with no
-    // actual data attached. In this section we will look at how to bridge
-    // that gap.
-
-    // Let's start by constructing a simple computation for us to work with:
-    BufHandle A("A", {64, 32}, kInt);
-    BufHandle B("B", {64, 32}, kInt);
-    Tensor X =
-        Compute("X", {64, 32}, [&](const VarHandle& i, const VarHandle& j) {
-          return A.load(i, j) + B.load(i, j);
-        });
-
-    // Let's lower it to a loop nest, as we did in the previous section. We
-    // can pass the Tensor object directly:
-    LoopNest loopnest({X});
-    std::cout << *loopnest.root_stmt() << std::endl;
-    // Prints:
-    // {
-    //   for (const auto i : c10::irange(64)) {
-    //     for (const auto j : c10::irange(32)) {
-    //       X[i, j] = (A[i, j]) + (B[i, j]);
-    //     }
-    //   }
-    // }
-
-    // Now imagine that we have two actual 64x32 tensors that we want to sum
-    // together. How do we pass them to the computation, and how do we carry
-    // it out?
-    //
-    // The Codegen object provides exactly that functionality. Codegen is an
-    // abstract class, and concrete codegens are derived from it. Currently
-    // we have three codegens:
-    //  1) Simple Evaluator,
-    //  2) LLVM Codegen for CPU,
-    //  3) CUDA Codegen.
-    // In this example we will use the Simple Evaluator, since it is
-    // available everywhere.
-
-    // To create a codegen, we need to provide the statement, which specifies
-    // the computation we want to perform, and a list of placeholders and
-    // tensors used in the computation. The latter part is crucial, since it
-    // is the only way the codegen can correlate symbols in the statement
-    // with the actual data arrays we pass in when we run the computation.
-    //
-    // Let's create a Simple IR Evaluator codegen for our computation:
-    SimpleIREvaluator ir_eval(loopnest.root_stmt(), {A, B, X});
-
-    // We are using the simplest codegen, in which almost no work is done at
-    // construction. Real codegens such as CUDA and LLVM perform compilation
-    // during that stage, so that everything is ready by the time we run the
-    // computation.
-
-    // Let's now create some inputs and run our computation on them:
-    std::vector<int> data_A(64 * 32, 3); // This will be the input A
-    std::vector<int> data_B(64 * 32, 5); // This will be the input B
-    std::vector<int> data_X(64 * 32, 0); // This will be used for the result
-
-    // Now let's invoke our codegen to perform the computation on our data.
-    // We need to provide as many arguments as there were placeholders and
-    // tensors at codegen construction time. Positions in these lists define
-    // how the real data arrays of this call (referred to as 'CallArg's in
-    // our codebase) correspond to the symbols (placeholders and tensors)
-    // used in the tensor expressions we constructed (referred to as
-    // 'BufferArg's).
-    // Thus, we provide three arguments: data_A, data_B, and data_X. data_A
-    // holds the data for placeholder A, data_B for placeholder B, and
-    // data_X receives the contents of tensor X.
-    ir_eval(data_A, data_B, data_X);
-
-    // Let's print one element from each array to verify that the
-    // computation did happen:
-    std::cout << "A[10] = " << data_A[10] << std::endl
-              << "B[10] = " << data_B[10] << std::endl
-              << "X[10] = A[10] + B[10] = " << data_X[10] << std::endl;
-    // Prints:
-    // A[10] = 3
-    // B[10] = 5
-    // X[10] = A[10] + B[10] = 8
-  }
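
Aside: the single-element print above extends naturally to a full check. The codegen object is reusable, so the computation can be re-run after resetting the output. A sketch (illustrative, not in the original file; it assumes <algorithm> and <cassert> are included):

    std::fill(data_X.begin(), data_X.end(), 0); // reset the result buffer
    ir_eval(data_A, data_B, data_X);            // run the kernel again
    for (int v : data_X) {
      assert(v == 3 + 5); // every element of X must equal A + B
    }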
-
-  std::cout << "*** Lowering TorchScript IR to TensorExpr IR ***" << std::endl;
-  {
-    // This section requires an LLVM-enabled PyTorch build, so we have to
-    // use a guard:
-#ifdef TORCH_ENABLE_LLVM
-
-    // Often we would like to convert TorchScript IR to TE rather than
-    // construct the TE IR from scratch. NNC provides an API for such
-    // lowering: it takes a TorchScript graph and returns an object that can
-    // be used to invoke the generated kernel. This API is currently used by
-    // the TorchScript JIT fuser and can also be used ahead of time to
-    // pre-compile parts of a model.
-    //
-    // To get familiar with this API, let's start by defining a simple
-    // TorchScript graph:
-    const auto graph_string = R"IR(
-      graph(%A : Float(5, 3, strides=[3, 1], device=cpu),
-            %B : Float(5, 3, strides=[3, 1], device=cpu)):
-        %AB : Float(5, 3, strides=[3, 1]) = aten::mul(%A, %B)
-        %one : int = prim::Constant[value=1]()
-        %AAB : Float(5, 3, strides=[3, 1]) = aten::mul(%A, %AB)
-        %AAB_plus_B: Float(5, 3, strides=[3, 1]) = aten::add(%AAB, %B, %one)
-        return (%AAB_plus_B))IR";
-    auto graph = std::make_shared<Graph>();
-    parseIR(graph_string, &*graph);
-
-    // This graph defines a simple computation of A*A*B + B, where A and B
-    // are 5x3 input tensors.
-
-    // To lower this TorchScript graph to TE, we just need to create a
-    // TensorExprKernel object. Its constructor builds the corresponding TE
-    // IR and compiles it for the given backend (in this example, for CPU,
-    // using the LLVM compiler).
-    TensorExprKernel kernel(graph);
-
-    // We can retrieve the generated TE stmt from the kernel object:
-    StmtPtr kernel_stmt = kernel.getCodeGenStmt();
-    std::cout << "TE Stmt constructed from TorchScript: " << std::endl
-              << *kernel_stmt << std::endl;
-    // Prints:
-    // TE Stmt constructed from TorchScript:
-    // {
-    //   for (const auto v : c10::irange(5)) {
-    //     for (const auto _tail_tail : c10::irange(3)) {
-    //       aten_add[_tail_tail + 3 * v] = (tA[_tail_tail + 3 * v]) *
-    //       ((tA[_tail_tail + 3 * v]) * (tB[_tail_tail + 3 * v])) +
-    //       (tB[_tail_tail + 3 * v]);
-    //     }
-    //   }
-    // }
-
-    // We can also examine the generated LLVM IR and assembly code:
-    std::cout << "Generated LLVM IR: " << std::endl;
-    auto ir_str = kernel.getCodeText("ir");
-    printLinesToFrom(ir_str, 15, 20);
-    // Prints:
-    // Generated LLVM IR:
-    // %9 = bitcast float* %2 to <8 x float>*
-    // %10 = load <8 x float>, <8 x float>* %9 ...
-    // %11 = bitcast float* %5 to <8 x float>*
-    // %12 = load <8 x float>, <8 x float>* %11 ...
-    // %13 = fmul <8 x float> %10, %12
-    // %14 = fmul <8 x float> %10, %13
-
-    std::cout << "Generated assembly: " << std::endl;
-    auto asm_str = kernel.getCodeText("asm");
-    printLinesToFrom(asm_str, 10, 15);
-    // Prints:
-    // Generated assembly:
-    // vmulps %ymm1, %ymm0, %ymm2
-    // vfmadd213ps %ymm1, %ymm0, %ymm2
-    // vmovups %ymm2, (%rax)
-    // vmovss 32(%rcx), %xmm0
-    // vmovss 32(%rdx), %xmm1
-    // vmulss %xmm1, %xmm0, %xmm2
-
-    // We can also execute the generated kernel:
-    auto A =
-        at::ones({5, 3}, torch::TensorOptions(torch::kCPU).dtype(at::kFloat)) *
-        2.0;
-    auto B =
-        at::ones({5, 3}, torch::TensorOptions(torch::kCPU).dtype(at::kFloat)) *
-        3.0;
-    std::vector<at::Tensor> inputs = {A, B};
-    std::vector<IValue> stack = torch::fmap<IValue>(inputs);
-    kernel.run(stack);
-    auto R = stack[0].toTensor();
-
-    // Let's print one of the elements from the result tensor to verify that
-    // the computation did happen and was correct:
-    std::cout << "R[2][2] = " << R[2][2] << std::endl;
-    // Prints:
-    // R[2][2] = 15
-    // [ CPUFloatType{} ]
-#endif
-  }
-  return 0;
-}
-
-// A helper that prints the window of lines between 'from' and 'to' of a
-// multi-line string; used above to inspect snippets of the generated LLVM IR
-// and assembly.
-void printLinesToFrom(const std::string& input_str, int from, int to) {
-  std::istringstream f(input_str);
-  std::string s;
-  int idx = 0;
-  while (getline(f, s)) {
-    if (idx > from) {
-      std::cout << s << "\n";
-    }
-    if (idx++ > to) {
-      break;
-    }
-  }
-}
diff --git a/test/test_jit_fuser_te.py b/test/test_jit_fuser_te.py
index 8d3a8090c67a..c3e26d37da1b 100644
--- a/test/test_jit_fuser_te.py
+++ b/test/test_jit_fuser_te.py
@@ -2939,7 +2939,10 @@ def f({", ".join(param_names)}):
 
     @slowTest
     @onlyCPU
-    @ops(op_db, dtypes=OpDTypes.supported)
+    @ops(
+        [op for op in op_db if get_name(op) not in known_failures],
+        dtypes=OpDTypes.supported,
+    )
     def test_nnc_correctness(self, device, dtype, op):
         if not op.supports_tracing:
             self.skipTest("Requires tracing support")
diff --git a/torch/csrc/jit/runtime/static/ops.cpp b/torch/csrc/jit/runtime/static/ops.cpp
index d5586a5b9cd7..9e408682ca6c 100644
--- a/torch/csrc/jit/runtime/static/ops.cpp
+++ b/torch/csrc/jit/runtime/static/ops.cpp
@@ -1910,7 +1910,7 @@ REGISTER_OPERATOR_FUNCTOR(aten::div, aten_div, [](Node* n) -> SROperator {
     }
     auto& out_t = p_node->Output(0).toTensor();
 
-    if (in0_t.sizes() == in1_t.sizes() &&
+    if (te && te->checkInput(in0_t) && in0_t.sizes() == in1_t.sizes() &&
         in0_t.scalar_type() == in1_t.scalar_type() &&
         in0_t.strides() == in1_t.strides() && in0_t.is_contiguous() &&
         in0_t.scalar_type() == at::kFloat) {