Enable tensorexpr cpp tests in CI. try #2 (#35454)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/35454

Differential Revision: D20665160

Pulled By: Krovatkin

fbshipit-source-id: e04cbe92b2ee5a3288f3c4e5c83533bfea85bf85
This commit is contained in:
Nikolay Korovaiko
2020-03-27 12:03:42 -07:00
committed by Facebook GitHub Bot
parent 930d218fbf
commit 9e22d15f14
20 changed files with 779 additions and 531 deletions

View File

@ -1,5 +1,6 @@
#include <c10/util/Exception.h>
#include <test/cpp/jit/tests.h>
#include <test/cpp/tensorexpr/tests.h>
namespace torch {
namespace jit {
@ -25,5 +26,17 @@ JIT_TEST_API void runJITCPPTests(bool runCuda) {
testTorchSaveError();
}
#undef JIT_TEST
#define JIT_TEST(name) test##name();
JIT_TEST_API void runTENSOREXPRCPPTests(bool runCuda) {
TH_FORALL_TENSOREXPR_TESTS(JIT_TEST)
if (runCuda) {
#ifdef USE_CUDA
TH_FORALL_TENSOREXPR_TESTS_CUDA(JIT_TEST)
#endif
}
}
#undef JIT_TEST
} // namespace jit
} // namespace torch

View File

@ -3,6 +3,16 @@ set(TENSOREXPR_TEST_ROOT ${TORCH_ROOT}/test/cpp/tensorexpr)
file(GLOB TENSOREXPR_TEST_SRCS ${TENSOREXPR_TEST_ROOT}/test_*.cpp)
set(TENSOREXPR_TEST_SRCS ${TENSOREXPR_TEST_SRCS} PARENT_SCOPE)
# this is used for running cpp tests from python as part of test_jit.TestJit.test_tensorexpr_cpp
set(TENSOREXPR_TEST_SRCS_WITH_PADDED ${TENSOREXPR_TEST_SRCS} ${TENSOREXPR_TEST_ROOT}/padded_buffer.cpp)
if(NOT USE_CUDA)
list(REMOVE_ITEM TENSOREXPR_TEST_SRCS_WITH_PADDED ${TENSOREXPR_TEST_ROOT}/test_cuda.cpp)
endif()
if(NOT USE_LLVM)
list(REMOVE_ITEM TENSOREXPR_TEST_SRCS_WITH_PADDED ${TENSOREXPR_TEST_ROOT}/test_llvm.cpp)
endif()
set(TENSOREXPR_TEST_SRCS_WITH_PADDED ${TENSOREXPR_TEST_SRCS_WITH_PADDED} PARENT_SCOPE)
add_executable(test_tensorexpr
${TORCH_ROOT}/test/cpp/common/main.cpp
${TENSOREXPR_TEST_ROOT}/gtest.cpp

View File

@ -9,7 +9,7 @@ namespace jit {
TEST(TensorExprTest, name) { \
test##name(); \
}
TH_FORALL_TESTS(TENSOREXPR_GTEST)
TH_FORALL_TENSOREXPR_TESTS(TENSOREXPR_GTEST)
#undef TENSOREXPR_GTEST
#ifdef TORCH_ENABLE_LLVM
@ -17,7 +17,7 @@ TH_FORALL_TESTS(TENSOREXPR_GTEST)
TEST(TensorExprTest, name##_LLVM) { \
test##name(); \
}
TH_FORALL_TESTS_LLVM(TENSOREXPR_GTEST_LLVM)
TH_FORALL_TENSOREXPR_TESTS_LLVM(TENSOREXPR_GTEST_LLVM)
#undef TENSOREXPR_GTEST_LLVM
#endif
@ -26,7 +26,7 @@ TH_FORALL_TESTS_LLVM(TENSOREXPR_GTEST_LLVM)
TEST(TensorExprTest, name##_CUDA) { \
test##name(); \
}
TH_FORALL_TESTS_CUDA(TENSOREXPR_GTEST_CUDA)
TH_FORALL_TENSOREXPR_TESTS_CUDA(TENSOREXPR_GTEST_CUDA)
#undef TENSOREXPR_GTEST_CUDA
#endif

View File

@ -0,0 +1,118 @@
#pragma once
#include <cmath>
// Copyright 2005, Google Inc.
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//
// The Google C++ Testing and Mocking Framework (Google Test)
//
// This header file declares functions and macros used internally by
// Google Test. They are subject to change without notice.
using Bits = uint32_t;
// this avoids the "dereferencing type-punned pointer
// will break strict-aliasing rules" error
union Float {
float float_;
Bits bits_;
};
// # of bits in a number.
static const size_t kBitCount = 8*sizeof(Bits);
// The mask for the sign bit.
static const Bits kSignBitMask = static_cast<Bits>(1) << (kBitCount - 1);
// GOOGLETEST_CM0001 DO NOT DELETE
// Converts an integer from the sign-and-magnitude representation to
// the biased representation. More precisely, let N be 2 to the
// power of (kBitCount - 1), an integer x is represented by the
// unsigned number x + N.
//
// For instance,
//
// -N + 1 (the most negative number representable using
// sign-and-magnitude) is represented by 1;
// 0 is represented by N; and
// N - 1 (the biggest number representable using
// sign-and-magnitude) is represented by 2N - 1.
//
// Read http://en.wikipedia.org/wiki/Signed_number_representations
// for more details on signed number representations.
static Bits SignAndMagnitudeToBiased(const Bits &sam) {
if (kSignBitMask & sam) {
// sam represents a negative number.
return ~sam + 1;
} else {
// sam represents a positive number.
return kSignBitMask | sam;
}
}
// Given two numbers in the sign-and-magnitude representation,
// returns the distance between them as an unsigned number.
static Bits DistanceBetweenSignAndMagnitudeNumbers(const Bits &sam1,
const Bits &sam2) {
const Bits biased1 = SignAndMagnitudeToBiased(sam1);
const Bits biased2 = SignAndMagnitudeToBiased(sam2);
return (biased1 >= biased2) ? (biased1 - biased2) : (biased2 - biased1);
}
// How many ULP's (Units in the Last Place) we want to tolerate when
// comparing two numbers. The larger the value, the more error we
// allow. A 0 value means that two numbers must be exactly the same
// to be considered equal.
//
// The maximum error of a single floating-point operation is 0.5
// units in the last place. On Intel CPU's, all floating-point
// calculations are done with 80-bit precision, while double has 64
// bits. Therefore, 4 should be enough for ordinary use.
//
// See the following article for more details on ULP:
// http://randomascii.wordpress.com/2012/02/25/comparing-floating-point-numbers-2012-edition/
static const size_t kMaxUlps = 4;
// Returns true if and only if this number is at most kMaxUlps ULP's away
// from rhs. In particular, this function:
//
// - returns false if either number is (or both are) NAN.
// - treats really large numbers as almost equal to infinity.
// - thinks +0.0 and -0.0 are 0 DLP's apart.
inline bool AlmostEquals(float lhs, float rhs) {
// The IEEE standard says that any comparison operation involving
// a NAN must return false.
if (std::isnan(lhs) || std::isnan(rhs))
return false;
Float l = {lhs};
Float r = {rhs};
return DistanceBetweenSignAndMagnitudeNumbers(l.bits_, r.bits_) <= kMaxUlps;
}

View File

@ -1,9 +1,6 @@
#include "test/cpp/tensorexpr/padded_buffer.h"
#include <sstream>
#include <gtest/gtest.h>
#include <c10/util/Logging.h>
namespace torch {

View File

@ -169,12 +169,20 @@ class PaddedBuffer : public PaddedBufferBase {
// Verify the watermarks in the paddings are intact.
void ValidateWatermark() const {
for (int i = 0; i < kPaddingSize; i++) {
EXPECT_EQ(data_[i], kPaddingValue)
<< "left-side watermark broken: "
<< "index: " << i << ", name: " << name();
EXPECT_EQ(data_[i + total_size_ + kPaddingSize], kPaddingValue)
<< "right-side watermark broken: "
<< "index: " << i << ", name: " << name();
ASSERT_EQ(
data_[i],
kPaddingValue,
"left-side watermark broken: index: ",
i,
", name: ",
name());
ASSERT_EQ(
data_[i + total_size_ + kPaddingSize],
kPaddingValue,
"right-side watermark broken: index: ",
i,
", name: ",
name());
}
}
@ -183,9 +191,13 @@ class PaddedBuffer : public PaddedBufferBase {
DCHECK(backup_data_.size() == data_.size())
<< "Please make sure you have call Backup() before calling CheckBackup()";
for (int i = 0; i < total_size_; i++) {
EXPECT_EQ(data_[i + kPaddingSize], backup_data_[i + kPaddingSize])
<< "mismatch against backup, "
<< "index: " << i << ", name: " << name();
ASSERT_EQ(
data_[i + kPaddingSize],
backup_data_[i + kPaddingSize],
"mismatch against backup, index: ",
i,
", name: ",
name());
}
}
@ -219,8 +231,8 @@ void ExpectAllEqual(const PaddedBuffer<T>& f1, const PaddedBuffer<T>& f2) {
f1.ValidateWatermark();
f2.ValidateWatermark();
for (int i = 0; i < total_size; i++) {
EXPECT_EQ(v1[kPaddingSize + i], v2[kPaddingSize + i])
<< CompareErrorMsg(f1, f2, i);
ASSERT_EQ(
v1[kPaddingSize + i], v2[kPaddingSize + i], CompareErrorMsg(f1, f2, i));
}
}
@ -237,8 +249,11 @@ void ExpectAllNear(
f1.ValidateWatermark();
f2.ValidateWatermark();
for (int i = 0; i < total_size; i++) {
ASSERT_NEAR(v1[kPaddingSize + i], v2[kPaddingSize + i], abs_error);
// << CompareErrorMsg(f1, f2, i);
ASSERT_NEAR(
v1[kPaddingSize + i],
v2[kPaddingSize + i],
abs_error,
CompareErrorMsg(f1, f2, i));
}
}

View File

@ -1,9 +1,10 @@
#include <algorithm>
#include <sstream>
#include <stdexcept>
#include "test/cpp/tensorexpr/test_base.h"
#include <c10/macros/Macros.h>
#include "test/cpp/tensorexpr/padded_buffer.h"
#include "test/cpp/tensorexpr/test_base.h"
#include "torch/csrc/jit/tensorexpr/ir_printer.h"
namespace torch {
@ -34,8 +35,8 @@ void testATen_cast_Float() {
ir_eval(a_v, b_v);
for (int i = 0; i < kTotalSize; ++i) {
EXPECT_EQ(a_v(i), i) << "index: " << i;
EXPECT_EQ(b_v(i), static_cast<float>(i)) << "index: " << i;
ASSERT_EQ(a_v(i), i, "index: ", i);
ASSERT_EQ(b_v(i), static_cast<float>(i), "index: ", i);
}
}
@ -62,8 +63,8 @@ void testATennegInt() {
ir_eval(a_v, b_v);
for (int i = 0; i < kTotalSize; ++i) {
EXPECT_EQ(a_v(i), i) << "index: " << i;
EXPECT_EQ(b_v(i), -static_cast<float>(i)) << "index: " << i;
ASSERT_EQ(a_v(i), i, "index: ", i);
ASSERT_EQ(b_v(i), -static_cast<float>(i), "index: ", i);
}
}
@ -90,8 +91,8 @@ void testATennegFloat() {
ir_eval(a_v, b_v);
for (int i = 0; i < kTotalSize; ++i) {
EXPECT_EQ(a_v(i), i) << "index: " << i;
EXPECT_EQ(b_v(i), -i) << "index: " << i;
ASSERT_EQ(a_v(i), i, "index: ", i);
ASSERT_EQ(b_v(i), -i, "index: ", i);
}
}
@ -125,10 +126,10 @@ void testATenaddInt() {
ir_eval(a_v, b_v, c_v, d_v);
for (int i = 0; i < kTotalSize; ++i) {
EXPECT_EQ(a_v(i), i) << "index: " << i;
EXPECT_EQ(b_v(i), 2 * i + 1) << "index: " << i;
EXPECT_EQ(c_v(i), 3 * i + 2) << "index: " << i;
EXPECT_EQ(d_v(i), a_v(i) + b_v(i) * c_v(i)) << "index: " << i;
ASSERT_EQ(a_v(i), i, "index: ", i);
ASSERT_EQ(b_v(i), 2 * i + 1, "index: ", i);
ASSERT_EQ(c_v(i), 3 * i + 2, "index: ", i);
ASSERT_EQ(d_v(i), a_v(i) + b_v(i) * c_v(i), "index: ", i);
}
}
@ -162,10 +163,10 @@ void testATenaddFloat() {
ir_eval(a_v, b_v, c_v, d_v);
for (int i = 0; i < kTotalSize; ++i) {
EXPECT_EQ(a_v(i), i) << "index: " << i;
EXPECT_EQ(b_v(i), 2 * i + 1) << "index: " << i;
EXPECT_EQ(c_v(i), 3 * i + 2) << "index: " << i;
EXPECT_EQ(d_v(i), a_v(i) + b_v(i) * c_v(i)) << "index: " << i;
ASSERT_EQ(a_v(i), i, "index: ", i);
ASSERT_EQ(b_v(i), 2 * i + 1, "index: ", i);
ASSERT_EQ(c_v(i), 3 * i + 2, "index: ", i);
ASSERT_EQ(d_v(i), a_v(i) + b_v(i) * c_v(i), "index: ", i);
}
}
@ -199,10 +200,10 @@ void testATensubInt() {
ir_eval(a_v, b_v, c_v, d_v);
for (int i = 0; i < kTotalSize; ++i) {
EXPECT_EQ(a_v(i), i) << "index: " << i;
EXPECT_EQ(b_v(i), 2 * i + 1) << "index: " << i;
EXPECT_EQ(c_v(i), 3 * i + 2) << "index: " << i;
EXPECT_EQ(d_v(i), a_v(i) - b_v(i) * c_v(i)) << "index: " << i;
ASSERT_EQ(a_v(i), i, "index: ", i);
ASSERT_EQ(b_v(i), 2 * i + 1, "index: ", i);
ASSERT_EQ(c_v(i), 3 * i + 2, "index: ", i);
ASSERT_EQ(d_v(i), a_v(i) - b_v(i) * c_v(i), "index: ", i);
}
}
@ -236,10 +237,10 @@ void testATensubFloat() {
ir_eval(a_v, b_v, c_v, d_v);
for (int i = 0; i < kTotalSize; ++i) {
EXPECT_EQ(a_v(i), i) << "index: " << i;
EXPECT_EQ(b_v(i), 2 * i + 1) << "index: " << i;
EXPECT_EQ(c_v(i), 3 * i + 2) << "index: " << i;
EXPECT_EQ(d_v(i), a_v(i) - b_v(i) * c_v(i)) << "index: " << i;
ASSERT_EQ(a_v(i), i, "index: ", i);
ASSERT_EQ(b_v(i), 2 * i + 1, "index: ", i);
ASSERT_EQ(c_v(i), 3 * i + 2, "index: ", i);
ASSERT_EQ(d_v(i), a_v(i) - b_v(i) * c_v(i), "index: ", i);
}
}
@ -274,10 +275,10 @@ void testATenlerp() {
ir_eval(a_v, b_v, c_v, d_v);
for (int i = 0; i < kTotalSize; ++i) {
EXPECT_EQ(a_v(i), i) << "index: " << i;
EXPECT_EQ(b_v(i), 2 * i + 1) << "index: " << i;
EXPECT_EQ(c_v(i), 3 * i + 2) << "index: " << i;
EXPECT_EQ(d_v(i), a_v(i) + c_v(i) * (b_v(i) - a_v(i))) << "index: " << i;
ASSERT_EQ(a_v(i), i, "index: ", i);
ASSERT_EQ(b_v(i), 2 * i + 1, "index: ", i);
ASSERT_EQ(c_v(i), 3 * i + 2, "index: ", i);
ASSERT_EQ(d_v(i), a_v(i) + c_v(i) * (b_v(i) - a_v(i)), "index: ", i);
}
}
@ -316,11 +317,11 @@ void testATenaddcmulInt() {
ir_eval(a_v, b_v, c_v, d_v, e_v);
for (int i = 0; i < kTotalSize; ++i) {
EXPECT_EQ(a_v(i), i) << "index: " << i;
EXPECT_EQ(b_v(i), 2 * i + 1) << "index: " << i;
EXPECT_EQ(c_v(i), 3 * i + 2) << "index: " << i;
EXPECT_EQ(d_v(i), 5 * i + 3) << "index: " << i;
EXPECT_EQ(e_v(i), a_v(i) + b_v(i) * c_v(i) * d_v(i)) << "index: " << i;
ASSERT_EQ(a_v(i), i, "index: ", i);
ASSERT_EQ(b_v(i), 2 * i + 1, "index: ", i);
ASSERT_EQ(c_v(i), 3 * i + 2, "index: ", i);
ASSERT_EQ(d_v(i), 5 * i + 3, "index: ", i);
ASSERT_EQ(e_v(i), a_v(i) + b_v(i) * c_v(i) * d_v(i), "index: ", i);
}
}
@ -359,11 +360,11 @@ void testATenaddcmulFloat() {
ir_eval(a_v, b_v, c_v, d_v, e_v);
for (int i = 0; i < kTotalSize; ++i) {
EXPECT_EQ(a_v(i), i) << "index: " << i;
EXPECT_EQ(b_v(i), 2 * i + 1) << "index: " << i;
EXPECT_EQ(c_v(i), 3 * i + 2) << "index: " << i;
EXPECT_EQ(d_v(i), 5 * i + 3) << "index: " << i;
EXPECT_FLOAT_EQ(e_v(i), a_v(i) + b_v(i) * c_v(i) * d_v(i)) << "index: " << i;
ASSERT_EQ(a_v(i), i, "index: ", i);
ASSERT_EQ(b_v(i), 2 * i + 1, "index: ", i);
ASSERT_EQ(c_v(i), 3 * i + 2, "index: ", i);
ASSERT_EQ(d_v(i), 5 * i + 3, "index: ", i);
ASSERT_FLOAT_EQ(e_v(i), a_v(i) + b_v(i) * c_v(i) * d_v(i), "index: ", i);
}
}
@ -393,9 +394,9 @@ void testATenmulInt() {
ir_eval(a_v, b_v, c_v);
for (int i = 0; i < kTotalSize; ++i) {
EXPECT_EQ(a_v(i), i) << "index: " << i;
EXPECT_EQ(b_v(i), 2 * i + 1) << "index: " << i;
EXPECT_EQ(c_v(i), a_v(i) * b_v(i)) << "index: " << i;
ASSERT_EQ(a_v(i), i, "index: ", i);
ASSERT_EQ(b_v(i), 2 * i + 1, "index: ", i);
ASSERT_EQ(c_v(i), a_v(i) * b_v(i), "index: ", i);
}
}
@ -425,9 +426,9 @@ void testATenmulFloat() {
ir_eval(a_v, b_v, c_v);
for (int i = 0; i < kTotalSize; ++i) {
EXPECT_EQ(a_v(i), i) << "index: " << i;
EXPECT_EQ(b_v(i), 2 * i + 1) << "index: " << i;
EXPECT_EQ(c_v(i), a_v(i) * b_v(i)) << "index: " << i;
ASSERT_EQ(a_v(i), i, "index: ", i);
ASSERT_EQ(b_v(i), 2 * i + 1, "index: ", i);
ASSERT_EQ(c_v(i), a_v(i) * b_v(i), "index: ", i);
}
}
@ -457,9 +458,9 @@ void testATendivInt() {
ir_eval(a_v, b_v, c_v);
for (int i = 0; i < kTotalSize; ++i) {
EXPECT_EQ(a_v(i), 2 * i + 1) << "index: " << i;
EXPECT_EQ(b_v(i), i + 1) << "index: " << i;
EXPECT_EQ(c_v(i), a_v(i) / b_v(i)) << "index: " << i;
ASSERT_EQ(a_v(i), 2 * i + 1, "index: ", i);
ASSERT_EQ(b_v(i), i + 1, "index: ", i);
ASSERT_EQ(c_v(i), a_v(i) / b_v(i), "index: ", i);
}
}
@ -489,9 +490,9 @@ void testATendivFloat() {
ir_eval(a_v, b_v, c_v);
for (int i = 0; i < kTotalSize; ++i) {
EXPECT_EQ(a_v(i), 2 * i + 1) << "index: " << i;
EXPECT_EQ(b_v(i), i + 1) << "index: " << i;
EXPECT_EQ(c_v(i), a_v(i) / b_v(i)) << "index: " << i;
ASSERT_EQ(a_v(i), 2 * i + 1, "index: ", i);
ASSERT_EQ(b_v(i), i + 1, "index: ", i);
ASSERT_EQ(c_v(i), a_v(i) / b_v(i), "index: ", i);
}
}
@ -521,9 +522,9 @@ void testATenmaxInt() {
ir_eval(a_v, b_v, c_v);
for (int i = 0; i < kTotalSize; ++i) {
EXPECT_EQ(a_v(i), i) << "index: " << i;
EXPECT_EQ(b_v(i), 2 * i + 1) << "index: " << i;
EXPECT_EQ(c_v(i), std::max(a_v(i), b_v(i))) << "index: " << i;
ASSERT_EQ(a_v(i), i, "index: ", i);
ASSERT_EQ(b_v(i), 2 * i + 1, "index: ", i);
ASSERT_EQ(c_v(i), std::max(a_v(i), b_v(i)), "index: ", i);
}
}
@ -553,9 +554,9 @@ void testATenmaxFloat() {
ir_eval(a_v, b_v, c_v);
for (int i = 0; i < kTotalSize; ++i) {
EXPECT_EQ(a_v(i), i) << "index: " << i;
EXPECT_EQ(b_v(i), 2 * i + 1) << "index: " << i;
EXPECT_EQ(c_v(i), std::fmax(a_v(i), b_v(i))) << "index: " << i;
ASSERT_EQ(a_v(i), i, "index: ", i);
ASSERT_EQ(b_v(i), 2 * i + 1, "index: ", i);
ASSERT_EQ(c_v(i), std::fmax(a_v(i), b_v(i)), "index: ", i);
}
}
@ -585,9 +586,9 @@ void testATenminInt() {
ir_eval(a_v, b_v, c_v);
for (int i = 0; i < kTotalSize; ++i) {
EXPECT_EQ(a_v(i), i) << "index: " << i;
EXPECT_EQ(b_v(i), 2 * i + 1) << "index: " << i;
EXPECT_EQ(c_v(i), std::min(a_v(i), b_v(i))) << "index: " << i;
ASSERT_EQ(a_v(i), i, "index: ", i);
ASSERT_EQ(b_v(i), 2 * i + 1, "index: ", i);
ASSERT_EQ(c_v(i), std::min(a_v(i), b_v(i)), "index: ", i);
}
}
@ -617,9 +618,9 @@ void testATenminFloat() {
ir_eval(a_v, b_v, c_v);
for (int i = 0; i < kTotalSize; ++i) {
EXPECT_EQ(a_v(i), i) << "index: " << i;
EXPECT_EQ(b_v(i), 2 * i + 1) << "index: " << i;
EXPECT_EQ(c_v(i), std::fmin(a_v(i), b_v(i))) << "index: " << i;
ASSERT_EQ(a_v(i), i, "index: ", i);
ASSERT_EQ(b_v(i), 2 * i + 1, "index: ", i);
ASSERT_EQ(c_v(i), std::fmin(a_v(i), b_v(i)), "index: ", i);
}
}
@ -650,9 +651,9 @@ void testATen_sigmoid_backward() {
ir_eval(a_v, b_v, c_v);
for (int i = 0; i < kTotalSize; ++i) {
EXPECT_EQ(a_v(i), i) << "index: " << i;
EXPECT_EQ(b_v(i), 2 * i + 1) << "index: " << i;
EXPECT_EQ(c_v(i), a_v(i) * b_v(i) * (1.0f - b_v(i))) << "index: " << i;
ASSERT_EQ(a_v(i), i, "index: ", i);
ASSERT_EQ(b_v(i), 2 * i + 1, "index: ", i);
ASSERT_EQ(c_v(i), a_v(i) * b_v(i) * (1.0f - b_v(i)), "index: ", i);
}
}
@ -683,13 +684,13 @@ void testATen_tanh_backward() {
ir_eval(a_v, b_v, c_v);
for (int i = 0; i < kTotalSize; ++i) {
EXPECT_EQ(a_v(i), i) << "index: " << i;
EXPECT_EQ(b_v(i), 2 * i + 1) << "index: " << i;
EXPECT_EQ(c_v(i), a_v(i) * (1.0f - (b_v(i) * b_v(i)))) << "index: " << i;
ASSERT_EQ(a_v(i), i, "index: ", i);
ASSERT_EQ(b_v(i), 2 * i + 1, "index: ", i);
ASSERT_EQ(c_v(i), a_v(i) * (1.0f - (b_v(i) * b_v(i))), "index: ", i);
}
}
void testATenreciprocal() {
void __ubsan_ignore_float_divide_by_zero__ testATenreciprocal() {
KernelScope kernel_scope;
const int kTotalSize = 128;
Buffer a_buf(VarHandle("A", kHandle), kFloat, {ExprHandle(kTotalSize)});
@ -711,8 +712,8 @@ void testATenreciprocal() {
ir_eval(a_v, b_v);
for (int i = 0; i < kTotalSize; ++i) {
EXPECT_EQ(a_v(i), i) << "index: " << i;
EXPECT_EQ(b_v(i), 1.0f / i) << "index: " << i;
ASSERT_EQ(a_v(i), i, "index: ", i);
ASSERT_EQ(b_v(i), 1.0f / i, "index: ", i);
}
}
@ -738,8 +739,8 @@ void testATenreluInt() {
ir_eval(a_v, b_v);
for (int i = 0; i < kTotalSize; ++i) {
EXPECT_EQ(a_v(i), i - 64) << "index: " << i;
EXPECT_EQ(b_v(i), std::max(a_v(i), 0)) << "index: " << i;
ASSERT_EQ(a_v(i), i - 64, "index: ", i);
ASSERT_EQ(b_v(i), std::max(a_v(i), 0), "index: ", i);
}
}
@ -769,8 +770,8 @@ void testATenreluFloat() {
ir_eval(a_v, b_v);
for (int i = 0; i < kTotalSize; ++i) {
EXPECT_EQ(a_v(i), i - 64) << "index: " << i;
EXPECT_EQ(b_v(i), std::fmax(a_v(i), 0)) << "index: " << i;
ASSERT_EQ(a_v(i), i - 64, "index: ", i);
ASSERT_EQ(b_v(i), std::fmax(a_v(i), 0), "index: ", i);
}
}
@ -796,8 +797,8 @@ void testATenlogFloat() {
ir_eval(a_v, b_v);
for (int i = 0; i < kTotalSize; ++i) {
EXPECT_EQ(a_v(i), i + 10) << "index: " << i;
EXPECT_EQ(b_v(i), std::log(a_v(i))) << "index: " << i;
ASSERT_EQ(a_v(i), i + 10, "index: ", i);
ASSERT_EQ(b_v(i), std::log(a_v(i)), "index: ", i);
}
}
@ -823,8 +824,8 @@ void testATenlog10Float() {
ir_eval(a_v, b_v);
for (int i = 0; i < kTotalSize; ++i) {
EXPECT_EQ(a_v(i), i + 10) << "index: " << i;
EXPECT_EQ(b_v(i), std::log10(a_v(i))) << "index: " << i;
ASSERT_EQ(a_v(i), i + 10, "index: ", i);
ASSERT_EQ(b_v(i), std::log10(a_v(i)), "index: ", i);
}
}
@ -850,8 +851,8 @@ void testATenlog2Float() {
ir_eval(a_v, b_v);
for (int i = 0; i < kTotalSize; ++i) {
EXPECT_EQ(a_v(i), i + 10) << "index: " << i;
EXPECT_EQ(b_v(i), std::log2(a_v(i))) << "index: " << i;
ASSERT_EQ(a_v(i), i + 10, "index: ", i);
ASSERT_EQ(b_v(i), std::log2(a_v(i)), "index: ", i);
}
}
@ -877,8 +878,8 @@ void testATenexpFloat() {
ir_eval(a_v, b_v);
for (int i = 0; i < kTotalSize; ++i) {
EXPECT_EQ(a_v(i), i / 10.0f) << "index: " << i;
EXPECT_EQ(b_v(i), std::exp(a_v(i))) << "index: " << i;
ASSERT_EQ(a_v(i), i / 10.0f, "index: ", i);
ASSERT_EQ(b_v(i), std::exp(a_v(i)), "index: ", i);
}
}
@ -904,8 +905,8 @@ void testATenerfFloat() {
ir_eval(a_v, b_v);
for (int i = 0; i < kTotalSize; ++i) {
EXPECT_EQ(a_v(i), i / 10.0f) << "index: " << i;
EXPECT_EQ(b_v(i), std::erf(a_v(i))) << "index: " << i;
ASSERT_EQ(a_v(i), i / 10.0f, "index: ", i);
ASSERT_EQ(b_v(i), std::erf(a_v(i)), "index: ", i);
}
}
@ -931,8 +932,8 @@ void testATencosFloat() {
ir_eval(a_v, b_v);
for (int i = 0; i < kTotalSize; ++i) {
EXPECT_EQ(a_v(i), i / 10.0f) << "index: " << i;
EXPECT_EQ(b_v(i), std::cos(a_v(i))) << "index: " << i;
ASSERT_EQ(a_v(i), i / 10.0f, "index: ", i);
ASSERT_EQ(b_v(i), std::cos(a_v(i)), "index: ", i);
}
}

View File

@ -1,7 +1,45 @@
#pragma once
#if defined(USE_GTEST)
#include <gtest/gtest.h>
#include <test/cpp/common/support.h>
#else
#include "c10/util/Exception.h"
#include "test/cpp/tensorexpr/gtest_assert_float_eq.h"
#include <cmath>
#define ASSERT_EQ(x, y, ...) TORCH_INTERNAL_ASSERT((x) == (y), __VA_ARGS__)
#define ASSERT_FLOAT_EQ(x, y, ...) \
TORCH_INTERNAL_ASSERT(AlmostEquals((x), (y)), __VA_ARGS__)
#define ASSERT_NE(x, y, ...) TORCH_INTERNAL_ASSERT((x) != (y), __VA_ARGS__)
#define ASSERT_GT(x, y, ...) TORCH_INTERNAL_ASSERT((x) > (y), __VA_ARGS__)
#define ASSERT_GE(x, y, ...) TORCH_INTERNAL_ASSERT((x) >= (y), __VA_ARGS__)
#define ASSERT_LT(x, y, ...) TORCH_INTERNAL_ASSERT((x) < (y), __VA_ARGS__)
#define ASSERT_LE(x, y, ...) TORCH_INTERNAL_ASSERT((x) <= (y), __VA_ARGS__)
#define ASSERT_NEAR(x, y, a, ...) \
TORCH_INTERNAL_ASSERT(std::fabs((x) - (y)) < (a), __VA_ARGS__)
#define ASSERT_TRUE TORCH_INTERNAL_ASSERT
#define ASSERT_FALSE(x) ASSERT_TRUE(!(x))
#define ASSERT_THROWS_WITH(statement, substring) \
try { \
(void)statement; \
ASSERT_TRUE(false); \
} catch (const std::exception& e) { \
ASSERT_NE(std::string(e.what()).find(substring), std::string::npos); \
}
#define ASSERT_ANY_THROW(statement) \
{ \
bool threw = false; \
try { \
(void)statement; \
} catch (const std::exception& e) { \
threw = true; \
} \
ASSERT_TRUE(threw); \
}
#endif // defined(USE_GTEST)
namespace torch {
namespace jit {
@ -15,8 +53,8 @@ void ExpectAllNear(
const std::string& name = "") {
ASSERT_EQ(v1.size(), v2.size());
for (int i = 0; i < v1.size(); i++) {
EXPECT_NEAR(v1[i], v2[i], threshold)
<< "element index: " << i << ", name: " << name;
ASSERT_NEAR(
v1[i], v2[i], threshold, "element index: ", i, ", name: ", name);
}
}

View File

@ -261,7 +261,7 @@ void testCudaTestRand01() {
sum1 += v;
sum2 += v * v;
sum3 += v * v * v;
EXPECT_TRUE(v >= 0 && v < 1) << "invalid value: " << i << ", " << v;
ASSERT_TRUE(v >= 0 && v < 1, "invalid value: ", i, ", ", v);
}
sum1 /= N;
sum2 /= N;
@ -270,9 +270,9 @@ void testCudaTestRand01() {
float sum2_mean = 1.f / 3;
float sum3_mean = 1.f / 4;
EXPECT_NEAR(sum1, sum1_mean, 2e-2);
EXPECT_NEAR(sum2, sum2_mean, 2e-2);
EXPECT_NEAR(sum3, sum3_mean, 2e-2);
ASSERT_NEAR(sum1, sum1_mean, 2e-2);
ASSERT_NEAR(sum2, sum2_mean, 2e-2);
ASSERT_NEAR(sum3, sum3_mean, 2e-2);
cudaFree(c_dev);
}

View File

@ -27,7 +27,7 @@ void testExprBasicValueTest() {
ExprHandle a = IntImm::make(2), b = IntImm::make(3);
ExprHandle c = Add::make(a, b);
SimpleIRExprEval eval(c);
EXPECT_EQ(eval.value<int>(), 5);
ASSERT_EQ(eval.value<int>(), 5);
}
void testExprBasicValueTest02() {
@ -38,7 +38,7 @@ void testExprBasicValueTest02() {
ExprHandle d(5.0f);
ExprHandle f = (a + b) - (c + d);
SimpleIRExprEval eval(f);
EXPECT_EQ(eval.value<float>(), -4.0f);
ASSERT_EQ(eval.value<float>(), -4.0f);
}
void testExprLetTest01() {
@ -48,7 +48,7 @@ void testExprLetTest01() {
ExprHandle body = ExprHandle(2.f) + (x * ExprHandle(3.f) + ExprHandle(4.f));
ExprHandle result = Let::make(x, ExprHandle(3.f), body);
SimpleIRExprEval eval(result);
EXPECT_EQ(eval.value<float>(), 2 + (3 * 3 + 4));
ASSERT_EQ(eval.value<float>(), 2 + (3 * 3 + 4));
}
void testExprLetTest02() {
@ -61,7 +61,7 @@ void testExprLetTest02() {
ExprHandle e1 = Let::make(x, ExprHandle(3.f), body);
ExprHandle e2 = Let::make(y, ExprHandle(6.f), e1);
SimpleIRExprEval eval(e2);
EXPECT_EQ(eval.value<float>(), 2 + (3 * 3 + 4 * 6));
ASSERT_EQ(eval.value<float>(), 2 + (3 * 3 + 4 * 6));
}
void testExprLetStmtTest01() {
@ -97,7 +97,7 @@ void testExprIntTest() {
ExprHandle body = ExprHandle(2) + (x * ExprHandle(3) + ExprHandle(4));
ExprHandle result = Let::make(x, ExprHandle(3), body);
SimpleIRExprEval eval(result);
EXPECT_EQ(eval.value<int>(), 2 + (3 * 3 + 4));
ASSERT_EQ(eval.value<int>(), 2 + (3 * 3 + 4));
}
void testExprFloatTest() {
@ -108,7 +108,7 @@ void testExprFloatTest() {
ExprHandle((float)2) + (x * ExprHandle((float)3) + ExprHandle((float)4));
ExprHandle result = Let::make(x, ExprHandle((float)3), body);
SimpleIRExprEval eval(result);
EXPECT_EQ(eval.value<float>(), 2 + (3 * 3 + 4));
ASSERT_EQ(eval.value<float>(), 2 + (3 * 3 + 4));
}
void testExprByteTest() {
@ -119,7 +119,7 @@ void testExprByteTest() {
(x * ExprHandle((uint8_t)3) + ExprHandle((uint8_t)4));
ExprHandle result = Let::make(x, ExprHandle((uint8_t)3), body);
SimpleIRExprEval eval(result);
EXPECT_EQ(eval.value<uint8_t>(), 2 + (3 * 3 + 4));
ASSERT_EQ(eval.value<uint8_t>(), 2 + (3 * 3 + 4));
}
void testExprCharTest() {
@ -130,7 +130,7 @@ void testExprCharTest() {
(x * ExprHandle((int8_t)3) + ExprHandle((int8_t)4));
ExprHandle result = Let::make(x, ExprHandle((int8_t)3), body);
SimpleIRExprEval eval(result);
EXPECT_EQ(eval.value<int8_t>(), 2 + (3 * 3 + 4));
ASSERT_EQ(eval.value<int8_t>(), 2 + (3 * 3 + 4));
}
void testExprShortTest() {
@ -141,7 +141,7 @@ void testExprShortTest() {
(x * ExprHandle((int16_t)3) + ExprHandle((int16_t)4));
ExprHandle result = Let::make(x, ExprHandle((int16_t)3), body);
SimpleIRExprEval eval(result);
EXPECT_EQ(eval.value<int16_t>(), 2 + (3 * 3 + 4));
ASSERT_EQ(eval.value<int16_t>(), 2 + (3 * 3 + 4));
}
void testExprLongTest() {
@ -152,7 +152,7 @@ void testExprLongTest() {
(x * ExprHandle((int64_t)3) + ExprHandle((int64_t)4));
ExprHandle result = Let::make(x, ExprHandle((int64_t)3), body);
SimpleIRExprEval eval(result);
EXPECT_EQ(eval.value<int64_t>(), 2 + (3 * 3 + 4));
ASSERT_EQ(eval.value<int64_t>(), 2 + (3 * 3 + 4));
}
void testExprHalfTest() {
@ -163,7 +163,7 @@ void testExprHalfTest() {
(x * ExprHandle((at::Half)3) + ExprHandle((at::Half)4));
ExprHandle result = Let::make(x, ExprHandle((at::Half)3), body);
SimpleIRExprEval eval(result);
EXPECT_EQ(eval.value<at::Half>(), 2 + (3 * 3 + 4));
ASSERT_EQ(eval.value<at::Half>(), 2 + (3 * 3 + 4));
}
void testExprDoubleTest() {
@ -174,7 +174,7 @@ void testExprDoubleTest() {
(x * ExprHandle((double)3) + ExprHandle((double)4));
ExprHandle result = Let::make(x, ExprHandle((double)3), body);
SimpleIRExprEval eval(result);
EXPECT_EQ(eval.value<double>(), 2 + (3 * 3 + 4));
ASSERT_EQ(eval.value<double>(), 2 + (3 * 3 + 4));
}
void testExprVectorAdd01() {
KernelScope kernel_scope;
@ -211,9 +211,9 @@ void testExprVectorAdd01() {
Broadcast::make(1, kVectorSize));
Stmt* stmt = For::make(index, 0, kVectorCount, store_c);
EXPECT_EQ(load_a.dtype(), Dtype(kFloat, kVectorSize));
EXPECT_EQ(load_b.dtype(), Dtype(kFloat, kVectorSize));
EXPECT_EQ(value.dtype(), Dtype(kFloat, kVectorSize));
ASSERT_EQ(load_a.dtype(), Dtype(kFloat, kVectorSize));
ASSERT_EQ(load_b.dtype(), Dtype(kFloat, kVectorSize));
ASSERT_EQ(value.dtype(), Dtype(kFloat, kVectorSize));
PaddedBuffer<float> a_v(kTotalSize);
PaddedBuffer<float> b_v(kTotalSize);
@ -359,7 +359,7 @@ void testExprUnaryMath01() {
ExprHandle v = test_config.func(ExprHandle(input_v));
float v_ref = test_config.ref_func(input_v);
SimpleIRExprEval eval(v);
EXPECT_NEAR(eval.value<float>(), v_ref, 1e-6) << "fail: " << v;
ASSERT_NEAR(eval.value<float>(), v_ref, 1e-6, "fail: ", v);
}
}
@ -383,7 +383,7 @@ void testExprBinaryMath01() {
ExprHandle v_expr = test_config.func(ExprHandle(v1), ExprHandle(v2));
float v_ref = test_config.ref_func(v1, v2);
SimpleIRExprEval eval(v_expr);
EXPECT_NEAR(eval.value<float>(), v_ref, 1e-6) << "fail: " << v_expr;
ASSERT_NEAR(eval.value<float>(), v_ref, 1e-6, "fail: ", v_expr);
}
}
@ -396,7 +396,7 @@ void testExprBitwiseOps() {
ExprHandle f = (((a ^ (b << 1)) & c) >> 2) | d;
SimpleIRExprEval eval(f);
EXPECT_EQ(eval.value<int>(), 11);
ASSERT_EQ(eval.value<int>(), 11);
}
void testExprDynamicShapeAdd() {

View File

@ -18,7 +18,7 @@ void testIRPrinterBasicValueTest() {
std::stringstream ss;
ss << c;
EXPECT_EQ(ss.str(), "2 + 3");
ASSERT_EQ(ss.str(), "2 + 3");
}
void testIRPrinterBasicValueTest02() {
@ -31,7 +31,7 @@ void testIRPrinterBasicValueTest02() {
std::stringstream ss;
ss << f;
EXPECT_EQ(ss.str(), "(2.f + 3.f) - (4.f + 5.f)");
ASSERT_EQ(ss.str(), "(2.f + 3.f) - (4.f + 5.f)");
}
void testIRPrinterLetTest01() {
@ -43,7 +43,7 @@ void testIRPrinterLetTest01() {
std::stringstream ss;
ss << result;
EXPECT_EQ(ss.str(), "let x = 3.f in 2.f + (x * 3.f + 4.f)");
ASSERT_EQ(ss.str(), "let x = 3.f in 2.f + (x * 3.f + 4.f)");
}
void testIRPrinterLetTest02() {
@ -58,7 +58,7 @@ void testIRPrinterLetTest02() {
std::stringstream ss;
ss << e2;
EXPECT_EQ(
ASSERT_EQ(
ss.str(), "let y = 6.f in (let x = 3.f in 2.f + (x * 3.f + 4.f * y))");
}
@ -74,7 +74,7 @@ void testIRPrinterCastTest() {
std::stringstream ss;
ss << e2;
EXPECT_EQ(
ASSERT_EQ(
ss.str(),
"let y = 6.f in (let x = int(3.f) in 2.f + (x * 3.f + 4.f * y))");
}

View File

@ -38,9 +38,9 @@ using LLVMExprEval = ExprEval<LLVMCodeGen>;
auto a = Name##Imm::make(Val); \
LLVMExprEval cg(a); \
if (std::is_floating_point<decltype(Val)>()) { \
EXPECT_NEAR(cg.value<Type>(), Val, 0.1); \
ASSERT_NEAR(cg.value<Type>(), Val, 0.1); \
} else { \
EXPECT_EQ(cg.value<Type>(), Val); \
ASSERT_EQ(cg.value<Type>(), Val); \
} \
}
TEST_LLVM_SCALAR_TYPES(IMM_TEST)
@ -54,9 +54,9 @@ TEST_LLVM_SCALAR_TYPES(IMM_TEST)
auto c = Add::make(a, b); \
LLVMExprEval cg(c); \
if (std::is_floating_point<decltype(Val)>()) { \
EXPECT_NEAR(cg.value<Type>(), Val * 3, 0.1); \
ASSERT_NEAR(cg.value<Type>(), Val * 3, 0.1); \
} else { \
EXPECT_EQ(cg.value<Type>(), Val * 3); \
ASSERT_EQ(cg.value<Type>(), Val * 3); \
} \
}
TEST_LLVM_SCALAR_TYPES(ADD_TEST)
@ -70,9 +70,9 @@ TEST_LLVM_SCALAR_TYPES(ADD_TEST)
auto c = Sub::make(a, b); \
LLVMExprEval cg(c); \
if (std::is_floating_point<decltype(Val)>()) { \
EXPECT_NEAR(cg.value<Type>(), Val, 0.1); \
ASSERT_NEAR(cg.value<Type>(), Val, 0.1); \
} else { \
EXPECT_EQ(cg.value<Type>(), Val); \
ASSERT_EQ(cg.value<Type>(), Val); \
} \
}
TEST_LLVM_SCALAR_TYPES(SUB_TEST)
@ -86,9 +86,9 @@ TEST_LLVM_SCALAR_TYPES(SUB_TEST)
auto c = Mul::make(a, b); \
LLVMExprEval cg(c); \
if (std::is_floating_point<decltype(Val)>()) { \
EXPECT_NEAR(cg.value<Type>(), Val * 4, 0.1); \
ASSERT_NEAR(cg.value<Type>(), Val * 4, 0.1); \
} else { \
EXPECT_EQ(cg.value<Type>(), Val * 4); \
ASSERT_EQ(cg.value<Type>(), Val * 4); \
} \
}
TEST_LLVM_SCALAR_TYPES(MUL_TEST)
@ -102,9 +102,9 @@ TEST_LLVM_SCALAR_TYPES(MUL_TEST)
auto c = Div::make(a, b); \
LLVMExprEval cg(c); \
if (std::is_floating_point<decltype(Val)>()) { \
EXPECT_NEAR(cg.value<Type>(), 2, 0.1); \
ASSERT_NEAR(cg.value<Type>(), 2, 0.1); \
} else { \
EXPECT_EQ(cg.value<Type>(), 2); \
ASSERT_EQ(cg.value<Type>(), 2); \
} \
}
TEST_LLVM_SCALAR_TYPES(DIV_TEST)
@ -115,7 +115,7 @@ void testLLVMIntToFloatCastTest() {
auto a = IntImm::make(2);
auto b = Cast::make(kFloat, a);
LLVMExprEval cg(b, {});
EXPECT_EQ(cg.value<float>(), 2.0);
ASSERT_EQ(cg.value<float>(), 2.0);
}
void testLLVMFloatToIntCastTest() {
@ -123,7 +123,7 @@ void testLLVMFloatToIntCastTest() {
auto a = FloatImm::make(2.0);
auto b = Cast::make(kInt, a);
LLVMExprEval cg(b);
EXPECT_EQ(cg.value<int>(), 2);
ASSERT_EQ(cg.value<int>(), 2);
}
void testLLVMIntToLongCastTest() {
@ -131,7 +131,7 @@ void testLLVMIntToLongCastTest() {
auto a = IntImm::make(12345);
auto b = Cast::make(kLong, a);
LLVMExprEval cg(b);
EXPECT_EQ(cg.value<int64_t>(), 12345);
ASSERT_EQ(cg.value<int64_t>(), 12345);
}
void testLLVMByteToCharCastTest() {
@ -139,7 +139,7 @@ void testLLVMByteToCharCastTest() {
auto a = ByteImm::make(250);
auto b = Cast::make(kChar, a);
LLVMExprEval cg(b);
EXPECT_EQ(cg.value<int8_t>(), (int8_t)250);
ASSERT_EQ(cg.value<int8_t>(), (int8_t)250);
}
void testLLVMHalfToLongCastTest() {
@ -147,7 +147,7 @@ void testLLVMHalfToLongCastTest() {
auto a = HalfImm::make(2.0);
auto b = Cast::make(kLong, a);
LLVMExprEval cg(b);
EXPECT_EQ(cg.value<int64_t>(), 2);
ASSERT_EQ(cg.value<int64_t>(), 2);
}
void testLLVMByteToDoubleCastTest() {
@ -155,7 +155,7 @@ void testLLVMByteToDoubleCastTest() {
auto a = ByteImm::make(2);
auto b = Cast::make(kDouble, a);
LLVMExprEval cg(b);
EXPECT_EQ(cg.value<double>(), 2);
ASSERT_EQ(cg.value<double>(), 2);
}
void testLLVMLetTest01() {
@ -165,7 +165,7 @@ void testLLVMLetTest01() {
ExprHandle body = ExprHandle(2.f) + (x * ExprHandle(3.f) + ExprHandle(4.f));
ExprHandle result = Let::make(x, ExprHandle(3.f), body);
LLVMExprEval cg(result, {});
EXPECT_EQ(cg.value<float>(), 2.f + (3.f * 3.f + 4.f));
ASSERT_EQ(cg.value<float>(), 2.f + (3.f * 3.f + 4.f));
}
void testLLVMLetTest02() {
@ -178,7 +178,7 @@ void testLLVMLetTest02() {
ExprHandle e1 = Let::make(x, ExprHandle(3.f), body);
ExprHandle e2 = Let::make(y, ExprHandle(6.f), e1);
LLVMExprEval cg(e2, {});
EXPECT_EQ(cg.value<float>(), 2.f + (3.f * 3.f + 4.f * 6.f));
ASSERT_EQ(cg.value<float>(), 2.f + (3.f * 3.f + 4.f * 6.f));
}
void testLLVMLetTestMultitype() {
@ -191,7 +191,7 @@ void testLLVMLetTestMultitype() {
ExprHandle e1 = Let::make(x, ExprHandle((uint8_t)3), body);
ExprHandle e2 = Let::make(y, ExprHandle((at::Half)6.f), e1);
LLVMExprEval cg(e2, {});
EXPECT_EQ(cg.value<double>(), 2.f + (3 * 3 + 4 * 6.f));
ASSERT_EQ(cg.value<double>(), 2.f + (3 * 3 + 4 * 6.f));
}
void testLLVMBufferTest() {
@ -201,7 +201,7 @@ void testLLVMBufferTest() {
std::vector<void*> args({v.data()});
auto rv = IntImm::make(0);
LLVMExprEval cg(rv, {a});
EXPECT_EQ(cg.value<int>(args), 0);
ASSERT_EQ(cg.value<int>(args), 0);
}
void testLLVMBlockTest() {
@ -217,9 +217,9 @@ void testLLVMBlockTest() {
});
LLVMCodeGen cg(block, {a});
EXPECT_EQ(cg.value<int>(args), 0);
EXPECT_EQ(v[0], 4);
EXPECT_EQ(v[1], 4);
ASSERT_EQ(cg.value<int>(args), 0);
ASSERT_EQ(v[0], 4);
ASSERT_EQ(v[1], 4);
}
void testLLVMLoadStoreTest() {
@ -236,9 +236,9 @@ void testLLVMLoadStoreTest() {
IntImm::make(1));
LLVMCodeGen cg(store, {a, b});
std::vector<void*> args({a_buffer.data(), b_buffer.data()});
EXPECT_EQ(cg.value<int>(args), 0);
EXPECT_EQ(a_buffer[0], 42);
EXPECT_EQ(b_buffer[0], 42);
ASSERT_EQ(cg.value<int>(args), 0);
ASSERT_EQ(a_buffer[0], 42);
ASSERT_EQ(b_buffer[0], 42);
}
void testLLVMIfThenElseTest() {
@ -260,9 +260,9 @@ void testLLVMIfThenElseTest() {
IntImm::make(1));
LLVMCodeGen cg(store, {a, b, c});
std::vector<void*> args({a_buffer.data(), b_buffer.data(), c_buffer.data()});
EXPECT_EQ(cg.value<int>(args), 0);
EXPECT_EQ(a_buffer[0], 42);
EXPECT_EQ(b_buffer[0], 42);
ASSERT_EQ(cg.value<int>(args), 0);
ASSERT_EQ(a_buffer[0], 42);
ASSERT_EQ(b_buffer[0], 42);
}
void testLLVMVecLoadStoreTest() {
@ -279,22 +279,22 @@ void testLLVMVecLoadStoreTest() {
Broadcast::make(IntImm::make(1), 4));
LLVMCodeGen cg(store, {a, b});
std::vector<void*> args({a_buffer.data(), b_buffer.data()});
EXPECT_EQ(cg.value<int>(args), 0);
EXPECT_EQ(a_buffer[0], 1);
EXPECT_EQ(a_buffer[1], 1);
EXPECT_EQ(a_buffer[2], 1);
EXPECT_EQ(a_buffer[3], 1);
EXPECT_EQ(b_buffer[0], 1);
EXPECT_EQ(b_buffer[1], 1);
EXPECT_EQ(b_buffer[2], 1);
EXPECT_EQ(b_buffer[3], 1);
ASSERT_EQ(cg.value<int>(args), 0);
ASSERT_EQ(a_buffer[0], 1);
ASSERT_EQ(a_buffer[1], 1);
ASSERT_EQ(a_buffer[2], 1);
ASSERT_EQ(a_buffer[3], 1);
ASSERT_EQ(b_buffer[0], 1);
ASSERT_EQ(b_buffer[1], 1);
ASSERT_EQ(b_buffer[2], 1);
ASSERT_EQ(b_buffer[3], 1);
}
#define FLOAT_INTRINSICS_TEST(Name, Lanes) \
void testLLVMVecFloat_##Name##Lane##Lanes##Test() { \
void testLLVMVecFloat_##Name##Lane##Lanes##Test() { \
KernelScope kernel_scope; \
Buffer a(VarHandle("A", kHandle), kFloat, {1}); \
Buffer b(VarHandle("B", kHandle), kFloat, {1}); \
Buffer a(VarHandle("A", kHandle), kFloat, {1}); \
Buffer b(VarHandle("B", kHandle), kFloat, {1}); \
float val = 0.5f; \
std::vector<float> a_buffer(Lanes, val); \
std::vector<float> b_buffer(Lanes, val); \
@ -309,9 +309,9 @@ void testLLVMVecLoadStoreTest() {
LLVMCodeGen cg(store, {a, b}); \
std::vector<void*> args({a_buffer.data(), b_buffer.data()}); \
float ref = std::Name(0.5f); \
EXPECT_EQ(cg.value<int>(args), 0); \
ASSERT_EQ(cg.value<int>(args), 0); \
for (int i = 0; i < Lanes; i++) { \
EXPECT_FLOAT_EQ(a_buffer[i], val); \
ASSERT_FLOAT_EQ(a_buffer[i], val); \
} \
} // namespace jit
FLOAT_INTRINSICS_TEST(erf, 4)
@ -336,14 +336,14 @@ FLOAT_INTRINSICS_TEST(expm1, 8)
FLOAT_INTRINSICS_TEST(lgamma, 8)
#undef FLOAT_INTRINSICS_TEST
#define DOUBLE_INTRINSICS_TEST(Name, Lanes) \
void testLLVMVecDouble_##Name##Lane##Lanes##Test() { \
#define DOUBLE_INTRINSICS_TEST(Name, Lanes) \
void testLLVMVecDouble_##Name##Lane##Lanes##Test() { \
KernelScope kernel_scope; \
Buffer a(VarHandle("A", kHandle), kDouble, {1}); \
Buffer b(VarHandle("B", kHandle), kDouble, {1}); \
Buffer a(VarHandle("A", kHandle), kDouble, {1}); \
Buffer b(VarHandle("B", kHandle), kDouble, {1}); \
float val = 0.5f; \
std::vector<double> a_buffer(Lanes, val); \
std::vector<double> b_buffer(Lanes, val); \
std::vector<double> a_buffer(Lanes, val); \
std::vector<double> b_buffer(Lanes, val); \
auto store = Store::make( \
b, \
Ramp::make(0, 1, Lanes), \
@ -355,9 +355,9 @@ FLOAT_INTRINSICS_TEST(lgamma, 8)
LLVMCodeGen cg(store, {a, b}); \
std::vector<void*> args({a_buffer.data(), b_buffer.data()}); \
float ref = std::Name(0.5f); \
EXPECT_EQ(cg.value<int>(args), 0); \
ASSERT_EQ(cg.value<int>(args), 0); \
for (int i = 0; i < Lanes; i++) { \
EXPECT_FLOAT_EQ(a_buffer[i], val); \
ASSERT_FLOAT_EQ(a_buffer[i], val); \
} \
} // namespace jit
DOUBLE_INTRINSICS_TEST(erf, 2)
@ -395,7 +395,8 @@ void testLLVMVectorizerLoadStoreTest() {
Stmt* s = l.root_stmt();
l.vectorize(*dynamic_cast<Block*>(s)->stmts().begin());
EXPECT_TRUE(dynamic_cast<For*>(*dynamic_cast<Block*>(s)->stmts().begin()) == nullptr);
ASSERT_TRUE(
dynamic_cast<For*>(*dynamic_cast<Block*>(s)->stmts().begin()) == nullptr);
LLVMCodeGen cg(s, {a, c_buf});
@ -992,7 +993,7 @@ void testLLVMStoreFloat() {
LLVMCodeGen cg(expr, {result});
std::vector<void*> args({result_buffer.data()});
ASSERT_EQ(cg.value<int>(args), 0);
EXPECT_EQ(result_buffer[0], 3.14f);
ASSERT_EQ(result_buffer[0], 3.14f);
}
void testLLVMSimpleMath01() {
@ -1083,7 +1084,7 @@ void testLLVMBitwiseOps() {
ExprHandle f = (((a ^ (b << 1)) & c) >> 2) | d;
LLVMExprEval cg(f);
EXPECT_EQ(cg.value<int>(), 11);
ASSERT_EQ(cg.value<int>(), 11);
}
void testLLVMDynamicShapeAdd() {

View File

@ -18,11 +18,11 @@ void testConstantFoldSimple() {
ExprHandle f = (a + b);
ExprHandle newF = IRSimplifier::simplify(f);
EXPECT_NE(newF.AsNode<FloatImm>(), nullptr);
EXPECT_EQ(newF.AsNode<FloatImm>()->value(), 5);
ASSERT_NE(newF.AsNode<FloatImm>(), nullptr);
ASSERT_EQ(newF.AsNode<FloatImm>()->value(), 5);
SimpleIRExprEval eval(newF);
EXPECT_EQ(eval.value<float>(), 5.f);
ASSERT_EQ(eval.value<float>(), 5.f);
}
void testConstantFoldTwoLayer() {
@ -34,11 +34,11 @@ void testConstantFoldTwoLayer() {
ExprHandle f = (a + b) - (c + d);
ExprHandle newF = IRSimplifier::simplify(f);
EXPECT_NE(newF.AsNode<FloatImm>(), nullptr);
EXPECT_EQ(newF.AsNode<FloatImm>()->value(), -4);
ASSERT_NE(newF.AsNode<FloatImm>(), nullptr);
ASSERT_EQ(newF.AsNode<FloatImm>()->value(), -4);
SimpleIRExprEval eval(newF);
EXPECT_EQ(eval.value<float>(), -4.f);
ASSERT_EQ(eval.value<float>(), -4.f);
}
void testConstantFoldShifts() {
@ -49,11 +49,11 @@ void testConstantFoldShifts() {
ExprHandle f = ((a << b) << b) >> c;
ExprHandle newF = IRSimplifier::simplify(f);
EXPECT_NE(newF.AsNode<IntImm>(), nullptr);
EXPECT_EQ(newF.AsNode<IntImm>()->value(), 14);
ASSERT_NE(newF.AsNode<IntImm>(), nullptr);
ASSERT_EQ(newF.AsNode<IntImm>()->value(), 14);
SimpleIRExprEval eval(newF);
EXPECT_EQ(eval.value<int>(), 7 << (4 - 3));
ASSERT_EQ(eval.value<int>(), 7 << (4 - 3));
}
void testConstantFoldBitwise() {
@ -64,11 +64,11 @@ void testConstantFoldBitwise() {
ExprHandle f = (a ^ b) & c;
ExprHandle newF = IRSimplifier::simplify(f);
EXPECT_NE(newF.AsNode<IntImm>(), nullptr);
EXPECT_EQ(newF.AsNode<IntImm>()->value(), 37);
ASSERT_NE(newF.AsNode<IntImm>(), nullptr);
ASSERT_EQ(newF.AsNode<IntImm>()->value(), 37);
SimpleIRExprEval eval(newF);
EXPECT_EQ(eval.value<int>(), (59 ^ 22) & 101);
ASSERT_EQ(eval.value<int>(), (59 ^ 22) & 101);
}
void testConstantFoldMultiOp() {
@ -82,12 +82,12 @@ void testConstantFoldMultiOp() {
ExprHandle fn = ((a / e) - (c + d)) * (f / b);
ExprHandle newF = IRSimplifier::simplify(fn);
EXPECT_NE(newF.AsNode<FloatImm>(), nullptr);
ASSERT_NE(newF.AsNode<FloatImm>(), nullptr);
SimpleIRExprEval eval(newF);
SimpleIRExprEval ref(fn);
EXPECT_EQ(eval.value<float>(), ref.value<float>());
ASSERT_EQ(eval.value<float>(), ref.value<float>());
}
void testConstantFoldMinMax() {
@ -100,13 +100,13 @@ void testConstantFoldMinMax() {
ExprHandle minHandle = Min::make(b, c, true);
ExprHandle fn = Max::make(a, minHandle, false);
EXPECT_EQ(fn.dtype().scalar_type(), ScalarType::Float);
ASSERT_EQ(fn.dtype().scalar_type(), ScalarType::Float);
ExprHandle newF = IRSimplifier::simplify(fn);
EXPECT_NE(newF.AsNode<FloatImm>(), nullptr);
ASSERT_NE(newF.AsNode<FloatImm>(), nullptr);
SimpleIRExprEval eval(newF);
EXPECT_EQ(eval.value<float>(), 15.f);
ASSERT_EQ(eval.value<float>(), 15.f);
}
void testConstantFoldIntrinsics() {
@ -122,13 +122,13 @@ void testConstantFoldIntrinsics() {
ExprHandle fn = Intrinsics::make(kFabs, rndHandle);
ExprHandle newF = IRSimplifier::simplify(fn);
EXPECT_NE(newF.AsNode<FloatImm>(), nullptr);
EXPECT_EQ(newF.AsNode<FloatImm>()->value(), 1);
ASSERT_NE(newF.AsNode<FloatImm>(), nullptr);
ASSERT_EQ(newF.AsNode<FloatImm>()->value(), 1);
SimpleIRExprEval eval(newF);
SimpleIRExprEval ref(fn);
EXPECT_EQ(eval.value<float>(), ref.value<float>());
ASSERT_EQ(eval.value<float>(), ref.value<float>());
}
void testConstantFoldWithVar() {
@ -138,12 +138,12 @@ void testConstantFoldWithVar() {
ExprHandle newF = IRSimplifier::simplify(body);
const Mul* root = newF.AsNode<Mul>();
EXPECT_NE(root, nullptr);
EXPECT_NE(dynamic_cast<const FloatImm*>(root->rhs()), nullptr);
ASSERT_NE(root, nullptr);
ASSERT_NE(dynamic_cast<const FloatImm*>(root->rhs()), nullptr);
ExprHandle result = Let::make(x, ExprHandle(3.f), newF);
SimpleIRExprEval eval(result);
EXPECT_EQ(eval.value<float>(), 3 * (2 + 4));
ASSERT_EQ(eval.value<float>(), 3 * (2 + 4));
}
void testUnFoldableExpr() {
@ -154,14 +154,14 @@ void testUnFoldableExpr() {
ExprHandle newF = IRSimplifier::simplify(body);
const Add* root = newF.AsNode<Add>();
EXPECT_NE(root, nullptr);
EXPECT_EQ(dynamic_cast<const FloatImm*>(root->lhs()), nullptr);
EXPECT_EQ(dynamic_cast<const FloatImm*>(root->rhs()), nullptr);
ASSERT_NE(root, nullptr);
ASSERT_EQ(dynamic_cast<const FloatImm*>(root->lhs()), nullptr);
ASSERT_EQ(dynamic_cast<const FloatImm*>(root->rhs()), nullptr);
ExprHandle result = Let::make(x, ExprHandle(3.f), newF);
result = Let::make(y, ExprHandle(2.f), result);
SimpleIRExprEval eval(result);
EXPECT_EQ(eval.value<float>(), 9 + 10);
ASSERT_EQ(eval.value<float>(), 9 + 10);
}
void testHashSimple() {
@ -177,12 +177,12 @@ void testHashSimple() {
auto hash_a = hasher.hash(a.node());
auto hash_f = hasher.hash(f.node());
EXPECT_NE(hash_x, 0);
EXPECT_NE(hash_a, 0);
EXPECT_NE(hash_f, 0);
EXPECT_NE(hash_x, hash_a);
EXPECT_NE(hash_x, hash_f);
EXPECT_NE(hash_a, hash_f);
ASSERT_NE(hash_x, 0);
ASSERT_NE(hash_a, 0);
ASSERT_NE(hash_f, 0);
ASSERT_NE(hash_x, hash_a);
ASSERT_NE(hash_x, hash_f);
ASSERT_NE(hash_a, hash_f);
}
void testHashEquivalence() {
@ -192,7 +192,7 @@ void testHashEquivalence() {
ExprHandle f = (x * y) + (x * y);
const Add* root = f.AsNode<Add>();
EXPECT_NE(root, nullptr);
ASSERT_NE(root, nullptr);
HashProvider hasher;
auto hash_f = hasher.hash(f.node());
@ -200,26 +200,26 @@ void testHashEquivalence() {
auto hash_r = hasher.hash(root->rhs());
// Root not equal to either branch.
EXPECT_NE(hash_f, hash_l);
EXPECT_NE(hash_f, hash_r);
ASSERT_NE(hash_f, hash_l);
ASSERT_NE(hash_f, hash_r);
// but branches are equal.
EXPECT_EQ(hash_l, hash_r);
ASSERT_EQ(hash_l, hash_r);
// Still equivalent if separate.
ExprHandle a(2);
ExprHandle f2 = x + a / y;
ExprHandle b(2);
ExprHandle f3 = x + b / y;
EXPECT_EQ(hasher.hash(f2.node()), hasher.hash(f3.node()));
ASSERT_EQ(hasher.hash(f2.node()), hasher.hash(f3.node()));
// Not equivalent if different vars (even with same name).
VarHandle z("x", kFloat);
ExprHandle f4 = z + b / y;
EXPECT_NE(hasher.hash(f2.node()), hasher.hash(f4.node()));
ASSERT_NE(hasher.hash(f2.node()), hasher.hash(f4.node()));
// Intrinsics sanity check.
ExprHandle f5 = Intrinsics::make(kSin, x) * Intrinsics::make(kCos, x);
EXPECT_NE(hasher.hash(f5.node()), 0);
ASSERT_NE(hasher.hash(f5.node()), 0);
}
void testHashEquivalenceAfterFolding() {
@ -231,7 +231,7 @@ void testHashEquivalenceAfterFolding() {
ExprHandle f = ((a + b) * x) * (c * x);
const Mul* root = f.AsNode<Mul>();
EXPECT_NE(root, nullptr);
ASSERT_NE(root, nullptr);
HashProvider hasher;
auto hash_f = hasher.hash(f.node());
@ -239,24 +239,24 @@ void testHashEquivalenceAfterFolding() {
auto hash_r = hasher.hash(root->rhs());
// Root not equal to either branch, and branches not equal.
EXPECT_NE(hash_f, hash_l);
EXPECT_NE(hash_f, hash_r);
EXPECT_NE(hash_l, hash_r);
ASSERT_NE(hash_f, hash_l);
ASSERT_NE(hash_f, hash_r);
ASSERT_NE(hash_l, hash_r);
ExprHandle newF = IRSimplifier::simplify(f);
const Mul* newRoot = newF.AsNode<Mul>();
EXPECT_NE(newRoot, nullptr);
ASSERT_NE(newRoot, nullptr);
auto hash_f_n = hasher.hash(newF.node());
auto hash_l_n = hasher.hash(newRoot->lhs());
auto hash_r_n = hasher.hash(newRoot->rhs());
// Root not equal to either branch.
EXPECT_NE(hash_f_n, hash_l_n);
EXPECT_NE(hash_f_n, hash_r_n);
ASSERT_NE(hash_f_n, hash_l_n);
ASSERT_NE(hash_f_n, hash_r_n);
// but branches are now equal.
EXPECT_EQ(hash_l_n, hash_r_n);
ASSERT_EQ(hash_l_n, hash_r_n);
}
void testHashDifferenceTypes() {
@ -278,7 +278,7 @@ void testHashDifferenceTypes() {
// Immediates of different types are not equal.
for (unsigned int i = 0; i < immediates.size(); ++i) {
for (unsigned int j = i + 1; j < immediates.size(); ++j) {
EXPECT_NE(hasher.hash(immediates[i]), hasher.hash(immediates[j]));
ASSERT_NE(hasher.hash(immediates[i]), hasher.hash(immediates[j]));
}
}
@ -289,7 +289,7 @@ void testHashDifferenceTypes() {
ExprHandle ff1 = IRSimplifier::simplify(f1);
ExprHandle ff2 = IRSimplifier::simplify(f2);
EXPECT_EQ(hasher.hash(ff1.node()), hasher.hash(ff2.node()));
ASSERT_EQ(hasher.hash(ff1.node()), hasher.hash(ff2.node()));
}
void testHashLargeExpression() {
@ -332,15 +332,15 @@ void testHashLargeExpression() {
HashProvider hasher;
auto hash_r = hasher.hash(if_stmt);
// We should not have to do any more work.
EXPECT_TRUE(hasher.cachedHash(memcpy_stmt));
ASSERT_TRUE(hasher.cachedHash(memcpy_stmt));
auto hash_t = hasher.hash(memcpy_stmt);
EXPECT_TRUE(hasher.cachedHash(store_ramp_stmt));
ASSERT_TRUE(hasher.cachedHash(store_ramp_stmt));
auto hash_f = hasher.hash(store_ramp_stmt);
// Root not equal to either branch, and branches not equal.
EXPECT_NE(hash_r, hash_t);
EXPECT_NE(hash_r, hash_f);
EXPECT_NE(hash_t, hash_f);
ASSERT_NE(hash_r, hash_t);
ASSERT_NE(hash_r, hash_f);
ASSERT_NE(hash_t, hash_f);
}
/// (2.f + x) + 4.f => x + 6.f
@ -351,13 +351,13 @@ void testSimplifyAdd() {
ExprHandle simplified = IRSimplifier::simplify(body);
const Add* root = simplified.AsNode<Add>();
EXPECT_NE(root, nullptr);
ASSERT_NE(root, nullptr);
const Var* lhs = dynamic_cast<const Var*>(root->lhs());
EXPECT_NE(lhs, nullptr);
EXPECT_EQ(lhs->name_hint(), "x");
ASSERT_NE(lhs, nullptr);
ASSERT_EQ(lhs->name_hint(), "x");
const FloatImm* rhs = dynamic_cast<const FloatImm*>(root->rhs());
EXPECT_NE(rhs, nullptr);
EXPECT_EQ(rhs->value(), 6.f);
ASSERT_NE(rhs, nullptr);
ASSERT_EQ(rhs->value(), 6.f);
}
/// (2.f - x) - 4.f => -2.f - x
@ -368,13 +368,13 @@ void testSimplifySub() {
ExprHandle simplified = IRSimplifier::simplify(body);
const Sub* root = simplified.AsNode<Sub>();
EXPECT_NE(root, nullptr);
ASSERT_NE(root, nullptr);
const FloatImm* lhs = dynamic_cast<const FloatImm*>(root->lhs());
EXPECT_NE(lhs, nullptr);
EXPECT_EQ(lhs->value(), -2.f);
ASSERT_NE(lhs, nullptr);
ASSERT_EQ(lhs->value(), -2.f);
const Var* rhs = dynamic_cast<const Var*>(root->rhs());
EXPECT_NE(rhs, nullptr);
EXPECT_EQ(rhs->name_hint(), "x");
ASSERT_NE(rhs, nullptr);
ASSERT_EQ(rhs->name_hint(), "x");
}
/// 2.f * (1.f - x) - 4.f => -6.f - (x * 2.f)
@ -385,18 +385,18 @@ void testSimplifyMultiLayer() {
ExprHandle simplified = IRSimplifier::simplify(body);
const Sub* root = simplified.AsNode<Sub>();
EXPECT_NE(root, nullptr);
ASSERT_NE(root, nullptr);
const FloatImm* lhs = dynamic_cast<const FloatImm*>(root->lhs());
EXPECT_NE(lhs, nullptr);
EXPECT_EQ(lhs->value(), -6.f);
ASSERT_NE(lhs, nullptr);
ASSERT_EQ(lhs->value(), -6.f);
const Mul* rhs = dynamic_cast<const Mul*>(root->rhs());
EXPECT_NE(rhs, nullptr);
ASSERT_NE(rhs, nullptr);
const Var* varX = dynamic_cast<const Var*>(rhs->lhs());
EXPECT_NE(varX, nullptr);
EXPECT_EQ(varX->name_hint(), "x");
ASSERT_NE(varX, nullptr);
ASSERT_EQ(varX->name_hint(), "x");
const FloatImm* mulRhs = dynamic_cast<const FloatImm*>(rhs->rhs());
EXPECT_NE(mulRhs, nullptr);
EXPECT_EQ(mulRhs->value(), 2.f);
ASSERT_NE(mulRhs, nullptr);
ASSERT_EQ(mulRhs->value(), 2.f);
}
/// 2 * (3 * x) - (x * 4) => x * 2
@ -408,13 +408,13 @@ void testSimplifyMultiTerm() {
ExprHandle simplified = IRSimplifier::simplify(body);
const Mul* root = simplified.AsNode<Mul>();
EXPECT_NE(root, nullptr);
ASSERT_NE(root, nullptr);
const Var* lhs = dynamic_cast<const Var*>(root->lhs());
EXPECT_NE(lhs, nullptr);
EXPECT_EQ(lhs->name_hint(), "x");
ASSERT_NE(lhs, nullptr);
ASSERT_EQ(lhs->name_hint(), "x");
const IntImm* rhs = dynamic_cast<const IntImm*>(root->rhs());
EXPECT_NE(rhs, nullptr);
EXPECT_EQ(rhs->value(), 2);
ASSERT_NE(rhs, nullptr);
ASSERT_EQ(rhs->value(), 2);
}
/// 2 * (3 * (f)x) - (x * 4) => x * 2.f
@ -426,13 +426,13 @@ void testSimplifyCasts() {
ExprHandle simplified = IRSimplifier::simplify(body);
const Mul* root = simplified.AsNode<Mul>();
EXPECT_NE(root, nullptr);
ASSERT_NE(root, nullptr);
const Var* lhs = dynamic_cast<const Var*>(root->lhs());
EXPECT_NE(lhs, nullptr);
EXPECT_EQ(lhs->name_hint(), "x");
ASSERT_NE(lhs, nullptr);
ASSERT_EQ(lhs->name_hint(), "x");
const FloatImm* rhs = dynamic_cast<const FloatImm*>(root->rhs());
EXPECT_NE(rhs, nullptr);
EXPECT_EQ(rhs->value(), 2);
ASSERT_NE(rhs, nullptr);
ASSERT_EQ(rhs->value(), 2);
}
/// (x + 0) * 1 => x
@ -443,8 +443,8 @@ void testSimplifyEliminatesNoOps() {
ExprHandle simplified = IRSimplifier::simplify(body);
const Var* root = simplified.AsNode<Var>();
EXPECT_NE(root, nullptr);
EXPECT_EQ(root->name_hint(), "x");
ASSERT_NE(root, nullptr);
ASSERT_EQ(root->name_hint(), "x");
}
/// Cannot simplify this.
@ -456,16 +456,16 @@ void testSimplifyMultiVar() {
ExprHandle simplified = IRSimplifier::simplify(body);
const Add* root = simplified.AsNode<Add>();
EXPECT_NE(root, nullptr);
ASSERT_NE(root, nullptr);
const Mul* lhs = dynamic_cast<const Mul*>(root->lhs());
EXPECT_NE(lhs, nullptr);
ASSERT_NE(lhs, nullptr);
const Var* varY = dynamic_cast<const Var*>(lhs->lhs());
EXPECT_EQ(varY->name_hint(), "y");
ASSERT_EQ(varY->name_hint(), "y");
const Mul* rhs = dynamic_cast<const Mul*>(root->rhs());
EXPECT_NE(rhs, nullptr);
ASSERT_NE(rhs, nullptr);
const Var* varX = dynamic_cast<const Var*>(rhs->lhs());
EXPECT_NE(varX, nullptr);
EXPECT_EQ(varX->name_hint(), "x");
ASSERT_NE(varX, nullptr);
ASSERT_EQ(varX->name_hint(), "x");
}
/// y + x * 0 => y
@ -477,8 +477,8 @@ void testSimplifyEliminatesVar() {
ExprHandle simplified = IRSimplifier::simplify(body);
const Var* root = simplified.AsNode<Var>();
EXPECT_NE(root, nullptr);
EXPECT_EQ(root->name_hint(), "y");
ASSERT_NE(root, nullptr);
ASSERT_EQ(root->name_hint(), "y");
}
} // namespace jit

View File

@ -10,34 +10,34 @@ void testTypeTest01() {
KernelScope kernel_scope;
{
Dtype dt1 = kInt;
EXPECT_EQ(dt1, kInt);
ASSERT_EQ(dt1, kInt);
}
{
Dtype dt2_a(kInt, 8);
Dtype dt2_b(kInt, 4);
Dtype dt2_c(ScalarType::Int, 8);
EXPECT_EQ(dt2_a, dt2_c);
EXPECT_NE(dt2_a, dt2_b);
ASSERT_EQ(dt2_a, dt2_c);
ASSERT_NE(dt2_a, dt2_b);
}
{
EXPECT_EQ(kInt, ToDtype<int>());
EXPECT_EQ(kFloat, ToDtype<float>());
EXPECT_EQ(kByte, ToDtype<uint8_t>());
EXPECT_EQ(kChar, ToDtype<int8_t>());
EXPECT_EQ(kShort, ToDtype<int16_t>());
EXPECT_EQ(kLong, ToDtype<int64_t>());
EXPECT_EQ(kHalf, ToDtype<at::Half>());
EXPECT_EQ(kDouble, ToDtype<double>());
EXPECT_EQ(kBool, ToDtype<bool>());
ASSERT_EQ(kInt, ToDtype<int>());
ASSERT_EQ(kFloat, ToDtype<float>());
ASSERT_EQ(kByte, ToDtype<uint8_t>());
ASSERT_EQ(kChar, ToDtype<int8_t>());
ASSERT_EQ(kShort, ToDtype<int16_t>());
ASSERT_EQ(kLong, ToDtype<int64_t>());
ASSERT_EQ(kHalf, ToDtype<at::Half>());
ASSERT_EQ(kDouble, ToDtype<double>());
ASSERT_EQ(kBool, ToDtype<bool>());
}
{
Dtype int32x8(kInt, 8);
Dtype float32x8(kFloat, 8);
EXPECT_NE(int32x8, float32x8);
EXPECT_EQ(float32x8, BinaryOpDtype(int32x8, float32x8));
EXPECT_EQ(float32x8, BinaryOpDtype(float32x8, int32x8));
EXPECT_EQ(int32x8, BinaryOpDtype(int32x8, int32x8));
EXPECT_EQ(float32x8, BinaryOpDtype(float32x8, float32x8));
ASSERT_NE(int32x8, float32x8);
ASSERT_EQ(float32x8, BinaryOpDtype(int32x8, float32x8));
ASSERT_EQ(float32x8, BinaryOpDtype(float32x8, int32x8));
ASSERT_EQ(int32x8, BinaryOpDtype(int32x8, int32x8));
ASSERT_EQ(float32x8, BinaryOpDtype(float32x8, float32x8));
}
}
@ -51,7 +51,7 @@ void testTypePropagation() {
(x * FloatImm::make(3.f) + FloatImm::make(4.f) * y);
ExprHandle e1 = Let::make(x, FloatImm::make(3.f), body);
ExprHandle e2 = Let::make(y, FloatImm::make(6.f), e1);
EXPECT_EQ(e2.dtype(), kFloat);
ASSERT_EQ(e2.dtype(), kFloat);
}
// Int to bigger int:
{
@ -62,7 +62,7 @@ void testTypePropagation() {
ShortImm::make(2.f) + (x * ShortImm::make(3) + ShortImm::make(4) * y);
ExprHandle e1 = Let::make(x, ShortImm::make(3), body);
ExprHandle e2 = Let::make(y, LongImm::make(6), e1);
EXPECT_EQ(e2.dtype(), kLong);
ASSERT_EQ(e2.dtype(), kLong);
}
// Float to bigger float:
{
@ -73,7 +73,7 @@ void testTypePropagation() {
HalfImm::make(2.f) + (x * HalfImm::make(3) + HalfImm::make(4) * y);
ExprHandle e1 = Let::make(x, HalfImm::make(3), body);
ExprHandle e2 = Let::make(y, DoubleImm::make(6), e1);
EXPECT_EQ(e2.dtype(), kDouble);
ASSERT_EQ(e2.dtype(), kDouble);
}
// Int to Float:
{
@ -84,7 +84,7 @@ void testTypePropagation() {
IntImm::make(2) + (x * IntImm::make(3) + IntImm::make(4) * y);
ExprHandle e1 = Let::make(x, FloatImm::make(3.f), body);
ExprHandle e2 = Let::make(y, IntImm::make(6), e1);
EXPECT_EQ(e2.dtype(), kFloat);
ASSERT_EQ(e2.dtype(), kFloat);
}
// Smaller float, bigger Int:
{
@ -95,7 +95,7 @@ void testTypePropagation() {
HalfImm::make(2) + (x * HalfImm::make(3) + HalfImm::make(4) * y);
ExprHandle e1 = Let::make(x, HalfImm::make(3), body);
ExprHandle e2 = Let::make(y, LongImm::make(6), e1);
EXPECT_EQ(e2.dtype(), kHalf);
ASSERT_EQ(e2.dtype(), kHalf);
}
// Bigger float, smaller Int:
{
@ -106,7 +106,7 @@ void testTypePropagation() {
CharImm::make(2) + (x * CharImm::make(3) + CharImm::make(4) * y);
ExprHandle e1 = Let::make(x, CharImm::make(3), body);
ExprHandle e2 = Let::make(y, DoubleImm::make(6), e1);
EXPECT_EQ(e2.dtype(), kDouble);
ASSERT_EQ(e2.dtype(), kDouble);
}
// Sign change char/byte upgrades to short:
{
@ -117,7 +117,7 @@ void testTypePropagation() {
CharImm::make(2) + (x * CharImm::make(3) + CharImm::make(4) * y);
ExprHandle e1 = Let::make(x, CharImm::make(3), body);
ExprHandle e2 = Let::make(y, ByteImm::make(6), e1);
EXPECT_EQ(e2.dtype(), kShort);
ASSERT_EQ(e2.dtype(), kShort);
}
}
} // namespace jit

View File

@ -9,240 +9,240 @@
namespace torch {
namespace jit {
#define TH_FORALL_TESTS(_) \
_(ExprBasicValueTest) \
_(ExprBasicValueTest02) \
_(ExprLetTest01) \
_(ExprLetStmtTest01) \
_(ExprLetTest02) \
_(ExprIntTest) \
_(ExprFloatTest) \
_(ExprByteTest) \
_(ExprCharTest) \
_(ExprShortTest) \
_(ExprLongTest) \
_(ExprHalfTest) \
_(ExprDoubleTest) \
_(ExprVectorAdd01) \
_(ExprCompareSelectEQ) \
_(ExprSubstitute01) \
_(ExprMath01) \
_(ExprUnaryMath01) \
_(ExprBinaryMath01) \
_(ExprDynamicShapeAdd) \
_(ExprBitwiseOps) \
_(IRPrinterBasicValueTest) \
_(IRPrinterBasicValueTest02) \
_(IRPrinterLetTest01) \
_(IRPrinterLetTest02) \
_(IRPrinterCastTest) \
_(ExprSimple01) \
_(ExprLower01) \
_(ExprSimple02) \
_(ExprSplitWithTailNone) \
_(ExprSplitWithMask01) \
_(ScheduleBroadcastAddBuffer) \
_(ScheduleFunctionCall01) \
_(ScheduleInlineFunc01) \
_(ScheduleFuserStyle) \
_(ScheduleFuserThreeArg) \
_(ScheduleDynamicShape2D) \
_(TypeTest01) \
_(TypePropagation) \
_(Cond01) \
_(IfThenElse01) \
_(IfThenElse02) \
_(ATen_cast_Float) \
_(ATennegInt) \
_(ATennegFloat) \
_(ATenaddInt) \
_(ATenaddFloat) \
_(ATensubInt) \
_(ATensubFloat) \
_(ATenlerp) \
_(ATenaddcmulInt) \
_(ATenaddcmulFloat) \
_(ATenmulInt) \
_(ATenmulFloat) \
_(ATendivInt) \
_(ATendivFloat) \
_(ATenmaxInt) \
_(ATenmaxFloat) \
_(ATenminInt) \
_(ATenminFloat) \
_(ATen_sigmoid_backward) \
_(ATen_tanh_backward) \
_(ATenreciprocal) \
_(ATenreluInt) \
_(ATenreluFloat) \
_(ATenlogFloat) \
_(ATenlog10Float) \
_(ATenlog2Float) \
_(ATenexpFloat) \
_(ATenerfFloat) \
_(ATencosFloat) \
_(ATeneqInt) \
_(ATengeInt) \
_(ATengtInt) \
_(ATenleInt) \
_(ATenltInt) \
_(ConstantFoldSimple) \
_(ConstantFoldTwoLayer) \
_(ConstantFoldShifts) \
_(ConstantFoldBitwise) \
_(ConstantFoldMultiOp) \
_(ConstantFoldMinMax) \
_(ConstantFoldIntrinsics) \
_(ConstantFoldWithVar) \
_(UnFoldableExpr) \
_(HashSimple) \
_(HashEquivalence) \
_(HashEquivalenceAfterFolding) \
_(HashDifferenceTypes) \
_(HashLargeExpression) \
_(SimplifyAdd) \
_(SimplifySub) \
_(SimplifyMultiLayer) \
_(SimplifyMultiTerm) \
_(SimplifyCasts) \
_(SimplifyEliminatesNoOps) \
_(SimplifyMultiVar) \
_(SimplifyEliminatesVar) \
#define TH_FORALL_TENSOREXPR_TESTS(_) \
_(ExprBasicValueTest) \
_(ExprBasicValueTest02) \
_(ExprLetTest01) \
_(ExprLetStmtTest01) \
_(ExprLetTest02) \
_(ExprIntTest) \
_(ExprFloatTest) \
_(ExprByteTest) \
_(ExprCharTest) \
_(ExprShortTest) \
_(ExprLongTest) \
_(ExprHalfTest) \
_(ExprDoubleTest) \
_(ExprVectorAdd01) \
_(ExprCompareSelectEQ) \
_(ExprSubstitute01) \
_(ExprMath01) \
_(ExprUnaryMath01) \
_(ExprBinaryMath01) \
_(ExprDynamicShapeAdd) \
_(ExprBitwiseOps) \
_(IRPrinterBasicValueTest) \
_(IRPrinterBasicValueTest02) \
_(IRPrinterLetTest01) \
_(IRPrinterLetTest02) \
_(IRPrinterCastTest) \
_(ExprSimple01) \
_(ExprLower01) \
_(ExprSimple02) \
_(ExprSplitWithTailNone) \
_(ExprSplitWithMask01) \
_(ScheduleBroadcastAddBuffer) \
_(ScheduleFunctionCall01) \
_(ScheduleInlineFunc01) \
_(ScheduleFuserStyle) \
_(ScheduleFuserThreeArg) \
_(ScheduleDynamicShape2D) \
_(TypeTest01) \
_(TypePropagation) \
_(Cond01) \
_(IfThenElse01) \
_(IfThenElse02) \
_(ATen_cast_Float) \
_(ATennegInt) \
_(ATennegFloat) \
_(ATenaddInt) \
_(ATenaddFloat) \
_(ATensubInt) \
_(ATensubFloat) \
_(ATenlerp) \
_(ATenaddcmulInt) \
_(ATenaddcmulFloat) \
_(ATenmulInt) \
_(ATenmulFloat) \
_(ATendivInt) \
_(ATendivFloat) \
_(ATenmaxInt) \
_(ATenmaxFloat) \
_(ATenminInt) \
_(ATenminFloat) \
_(ATen_sigmoid_backward) \
_(ATen_tanh_backward) \
_(ATenreciprocal) \
_(ATenreluInt) \
_(ATenreluFloat) \
_(ATenlogFloat) \
_(ATenlog10Float) \
_(ATenlog2Float) \
_(ATenexpFloat) \
_(ATenerfFloat) \
_(ATencosFloat) \
_(ATeneqInt) \
_(ATengeInt) \
_(ATengtInt) \
_(ATenleInt) \
_(ATenltInt) \
_(ConstantFoldSimple) \
_(ConstantFoldTwoLayer) \
_(ConstantFoldShifts) \
_(ConstantFoldBitwise) \
_(ConstantFoldMultiOp) \
_(ConstantFoldMinMax) \
_(ConstantFoldIntrinsics) \
_(ConstantFoldWithVar) \
_(UnFoldableExpr) \
_(HashSimple) \
_(HashEquivalence) \
_(HashEquivalenceAfterFolding) \
_(HashDifferenceTypes) \
_(HashLargeExpression) \
_(SimplifyAdd) \
_(SimplifySub) \
_(SimplifyMultiLayer) \
_(SimplifyMultiTerm) \
_(SimplifyCasts) \
_(SimplifyEliminatesNoOps) \
_(SimplifyMultiVar) \
_(SimplifyEliminatesVar) \
_(StmtClone)
#define TH_FORALL_TESTS_LLVM(_) \
_(LLVMByteImmTest) \
_(LLVMCharImmTest) \
_(LLVMShortImmTest) \
_(LLVMIntImmTest) \
_(LLVMLongImmTest) \
_(LLVMFloatImmTest) \
_(LLVMDoubleImmTest) \
_(LLVMHalfImmTest) \
_(LLVMByteAddTest) \
_(LLVMCharAddTest) \
_(LLVMShortAddTest) \
_(LLVMIntAddTest) \
_(LLVMLongAddTest) \
_(LLVMFloatAddTest) \
_(LLVMDoubleAddTest) \
_(LLVMHalfAddTest) \
_(LLVMByteSubTest) \
_(LLVMCharSubTest) \
_(LLVMShortSubTest) \
_(LLVMIntSubTest) \
_(LLVMLongSubTest) \
_(LLVMFloatSubTest) \
_(LLVMDoubleSubTest) \
_(LLVMHalfSubTest) \
_(LLVMByteMulTest) \
_(LLVMCharMulTest) \
_(LLVMShortMulTest) \
_(LLVMIntMulTest) \
_(LLVMLongMulTest) \
_(LLVMFloatMulTest) \
_(LLVMDoubleMulTest) \
_(LLVMHalfMulTest) \
_(LLVMByteDivTest) \
_(LLVMCharDivTest) \
_(LLVMShortDivTest) \
_(LLVMIntDivTest) \
_(LLVMLongDivTest) \
_(LLVMFloatDivTest) \
_(LLVMDoubleDivTest) \
_(LLVMHalfDivTest) \
_(LLVMIntToFloatCastTest) \
_(LLVMFloatToIntCastTest) \
_(LLVMIntToLongCastTest) \
_(LLVMByteToCharCastTest) \
_(LLVMHalfToLongCastTest) \
_(LLVMByteToDoubleCastTest) \
_(LLVMLetTest01) \
_(LLVMLetTest02) \
_(LLVMLetTestMultitype) \
_(LLVMBufferTest) \
_(LLVMBlockTest) \
_(LLVMLoadStoreTest) \
_(LLVMVecLoadStoreTest) \
_(LLVMVecFloat_acosLane4Test) \
_(LLVMVecFloat_asinLane4Test) \
_(LLVMVecFloat_atanLane4Test) \
_(LLVMVecFloat_coshLane4Test) \
_(LLVMVecFloat_sinhLane4Test) \
_(LLVMVecFloat_tanhLane4Test) \
_(LLVMVecFloat_erfLane4Test) \
_(LLVMVecFloat_erfcLane4Test) \
_(LLVMVecFloat_expm1Lane4Test) \
_(LLVMVecFloat_lgammaLane4Test) \
_(LLVMVecFloat_acosLane8Test) \
_(LLVMVecFloat_asinLane8Test) \
_(LLVMVecFloat_atanLane8Test) \
_(LLVMVecFloat_coshLane8Test) \
_(LLVMVecFloat_sinhLane8Test) \
_(LLVMVecFloat_tanhLane8Test) \
_(LLVMVecFloat_erfLane8Test) \
_(LLVMVecFloat_erfcLane8Test) \
_(LLVMVecFloat_expm1Lane8Test) \
_(LLVMVecFloat_lgammaLane8Test) \
_(LLVMVecDouble_acosLane2Test) \
_(LLVMVecDouble_asinLane2Test) \
_(LLVMVecDouble_atanLane2Test) \
_(LLVMVecDouble_coshLane2Test) \
_(LLVMVecDouble_sinhLane2Test) \
_(LLVMVecDouble_tanhLane2Test) \
_(LLVMVecDouble_erfLane2Test) \
_(LLVMVecDouble_erfcLane2Test) \
_(LLVMVecDouble_expm1Lane2Test) \
_(LLVMVecDouble_lgammaLane2Test) \
_(LLVMVecDouble_acosLane4Test) \
_(LLVMVecDouble_asinLane4Test) \
_(LLVMVecDouble_atanLane4Test) \
_(LLVMVecDouble_coshLane4Test) \
_(LLVMVecDouble_sinhLane4Test) \
_(LLVMVecDouble_tanhLane4Test) \
_(LLVMVecDouble_erfLane4Test) \
_(LLVMVecDouble_erfcLane4Test) \
_(LLVMVecDouble_expm1Lane4Test) \
_(LLVMVecDouble_lgammaLane4Test) \
_(LLVMMemcpyTest) \
_(LLVMBzeroTest) \
_(LLVMElemwiseAdd) \
_(LLVMElemwiseAddFloat) \
_(LLVMElemwiseLog10Float) \
_(LLVMElemwiseMaxInt) \
_(LLVMElemwiseMinInt) \
_(LLVMElemwiseMaxNumFloat) \
_(LLVMElemwiseMaxNumNaNFloat) \
_(LLVMElemwiseMinNumFloat) \
_(LLVMElemwiseMinNumNaNFloat) \
_(LLVMCompareSelectIntEQ) \
_(LLVMCompareSelectFloatEQ) \
_(LLVMStoreFloat) \
_(LLVMSimpleMath01) \
_(LLVMComputeMul) \
_(LLVMBroadcastAdd) \
_(LLVMBitwiseOps) \
_(LLVMDynamicShapeAdd) \
_(LLVMBindDynamicShapeAdd) \
_(LLVMTensorDynamicShapeAdd) \
_(LLVMDynamicShape2D) \
_(LLVMIfThenElseTest) \
#define TH_FORALL_TENSOREXPR_TESTS_LLVM(_) \
_(LLVMByteImmTest) \
_(LLVMCharImmTest) \
_(LLVMShortImmTest) \
_(LLVMIntImmTest) \
_(LLVMLongImmTest) \
_(LLVMFloatImmTest) \
_(LLVMDoubleImmTest) \
_(LLVMHalfImmTest) \
_(LLVMByteAddTest) \
_(LLVMCharAddTest) \
_(LLVMShortAddTest) \
_(LLVMIntAddTest) \
_(LLVMLongAddTest) \
_(LLVMFloatAddTest) \
_(LLVMDoubleAddTest) \
_(LLVMHalfAddTest) \
_(LLVMByteSubTest) \
_(LLVMCharSubTest) \
_(LLVMShortSubTest) \
_(LLVMIntSubTest) \
_(LLVMLongSubTest) \
_(LLVMFloatSubTest) \
_(LLVMDoubleSubTest) \
_(LLVMHalfSubTest) \
_(LLVMByteMulTest) \
_(LLVMCharMulTest) \
_(LLVMShortMulTest) \
_(LLVMIntMulTest) \
_(LLVMLongMulTest) \
_(LLVMFloatMulTest) \
_(LLVMDoubleMulTest) \
_(LLVMHalfMulTest) \
_(LLVMByteDivTest) \
_(LLVMCharDivTest) \
_(LLVMShortDivTest) \
_(LLVMIntDivTest) \
_(LLVMLongDivTest) \
_(LLVMFloatDivTest) \
_(LLVMDoubleDivTest) \
_(LLVMHalfDivTest) \
_(LLVMIntToFloatCastTest) \
_(LLVMFloatToIntCastTest) \
_(LLVMIntToLongCastTest) \
_(LLVMByteToCharCastTest) \
_(LLVMHalfToLongCastTest) \
_(LLVMByteToDoubleCastTest) \
_(LLVMLetTest01) \
_(LLVMLetTest02) \
_(LLVMLetTestMultitype) \
_(LLVMBufferTest) \
_(LLVMBlockTest) \
_(LLVMLoadStoreTest) \
_(LLVMVecLoadStoreTest) \
_(LLVMVecFloat_acosLane4Test) \
_(LLVMVecFloat_asinLane4Test) \
_(LLVMVecFloat_atanLane4Test) \
_(LLVMVecFloat_coshLane4Test) \
_(LLVMVecFloat_sinhLane4Test) \
_(LLVMVecFloat_tanhLane4Test) \
_(LLVMVecFloat_erfLane4Test) \
_(LLVMVecFloat_erfcLane4Test) \
_(LLVMVecFloat_expm1Lane4Test) \
_(LLVMVecFloat_lgammaLane4Test) \
_(LLVMVecFloat_acosLane8Test) \
_(LLVMVecFloat_asinLane8Test) \
_(LLVMVecFloat_atanLane8Test) \
_(LLVMVecFloat_coshLane8Test) \
_(LLVMVecFloat_sinhLane8Test) \
_(LLVMVecFloat_tanhLane8Test) \
_(LLVMVecFloat_erfLane8Test) \
_(LLVMVecFloat_erfcLane8Test) \
_(LLVMVecFloat_expm1Lane8Test) \
_(LLVMVecFloat_lgammaLane8Test) \
_(LLVMVecDouble_acosLane2Test) \
_(LLVMVecDouble_asinLane2Test) \
_(LLVMVecDouble_atanLane2Test) \
_(LLVMVecDouble_coshLane2Test) \
_(LLVMVecDouble_sinhLane2Test) \
_(LLVMVecDouble_tanhLane2Test) \
_(LLVMVecDouble_erfLane2Test) \
_(LLVMVecDouble_erfcLane2Test) \
_(LLVMVecDouble_expm1Lane2Test) \
_(LLVMVecDouble_lgammaLane2Test) \
_(LLVMVecDouble_acosLane4Test) \
_(LLVMVecDouble_asinLane4Test) \
_(LLVMVecDouble_atanLane4Test) \
_(LLVMVecDouble_coshLane4Test) \
_(LLVMVecDouble_sinhLane4Test) \
_(LLVMVecDouble_tanhLane4Test) \
_(LLVMVecDouble_erfLane4Test) \
_(LLVMVecDouble_erfcLane4Test) \
_(LLVMVecDouble_expm1Lane4Test) \
_(LLVMVecDouble_lgammaLane4Test) \
_(LLVMMemcpyTest) \
_(LLVMBzeroTest) \
_(LLVMElemwiseAdd) \
_(LLVMElemwiseAddFloat) \
_(LLVMElemwiseLog10Float) \
_(LLVMElemwiseMaxInt) \
_(LLVMElemwiseMinInt) \
_(LLVMElemwiseMaxNumFloat) \
_(LLVMElemwiseMaxNumNaNFloat) \
_(LLVMElemwiseMinNumFloat) \
_(LLVMElemwiseMinNumNaNFloat) \
_(LLVMCompareSelectIntEQ) \
_(LLVMCompareSelectFloatEQ) \
_(LLVMStoreFloat) \
_(LLVMSimpleMath01) \
_(LLVMComputeMul) \
_(LLVMBroadcastAdd) \
_(LLVMBitwiseOps) \
_(LLVMDynamicShapeAdd) \
_(LLVMBindDynamicShapeAdd) \
_(LLVMTensorDynamicShapeAdd) \
_(LLVMDynamicShape2D) \
_(LLVMIfThenElseTest) \
_(LLVMVectorizerLoadStoreTest)
#define TH_FORALL_TESTS_CUDA(_) \
_(CudaTestVectorAdd01) \
_(CudaTestVectorAdd02) \
_(CudaDynamicShape2D) \
_(CudaTestRand01) \
#define TH_FORALL_TENSOREXPR_TESTS_CUDA(_) \
_(CudaTestVectorAdd01) \
_(CudaTestVectorAdd02) \
_(CudaDynamicShape2D) \
_(CudaTestRand01) \
_(CudaDynamicShapeSplit)
#define DECLARE_TENSOREXPR_TEST(name) void test##name();
TH_FORALL_TESTS(DECLARE_TENSOREXPR_TEST)
TH_FORALL_TENSOREXPR_TESTS(DECLARE_TENSOREXPR_TEST)
#ifdef TORCH_ENABLE_LLVM
TH_FORALL_TESTS_LLVM(DECLARE_TENSOREXPR_TEST)
TH_FORALL_TENSOREXPR_TESTS_LLVM(DECLARE_TENSOREXPR_TEST)
#endif
#ifdef USE_CUDA
TH_FORALL_TESTS_CUDA(DECLARE_TENSOREXPR_TEST)
TH_FORALL_TENSOREXPR_TESTS_CUDA(DECLARE_TENSOREXPR_TEST)
#endif
#undef DECLARE_TENSOREXPR_TEST

View File

@ -3066,6 +3066,22 @@ graph(%Ra, %Rb):
torch._C._jit_run_cpp_tests(run_cuda=True)
tests_setup.shutdown()
@unittest.skipIf(IS_SANDCASTLE, "gtest runs these in sandcastle")
@unittest.skipIf(RUN_CUDA, "covered by test_tensorexpr_cuda")
@unittest.skipIf(IS_WINDOWS, "enable on windows")
@unittest.skipIf(not torch._C._has_tensorexpr_cpp_tests(), "Tests were not built, use BUILD_TEST=1")
@skipIfRocm
def test_tensorexpr_cpp(self):
torch._C._run_tensorexpr_cpp_tests(run_cuda=False)
@unittest.skipIf(IS_SANDCASTLE, "gtest runs these in sandcastle")
@unittest.skipIf(not RUN_CUDA, "covered by test_tensorexpr")
@unittest.skipIf(IS_WINDOWS, "enable on windows")
@unittest.skipIf(not torch._C._has_tensorexpr_cpp_tests(), "Tests were not built, use BUILD_TEST=1")
@skipIfRocm
def test_tensorexpr_cpp_cuda(self):
torch._C._run_tensorexpr_cpp_tests(run_cuda=True)
def test_batchnorm(self):
x = torch.ones(2, 2, 2, 2)
g, outputs, inputs = torch.jit._get_trace_graph(nn.BatchNorm2d(2), x,

View File

@ -401,9 +401,11 @@ def add_torch_libs():
"torch/csrc/utils/tensor_numpy.cpp",
"torch/csrc/utils/tensor_types.cpp",
"test/cpp/jit/torch_python_test.cpp",
"test/cpp/tensorexpr/padded_buffer.cpp",
]
libtorch_python_sources.extend(native.glob(["test/cpp/jit/test_*.cpp"]))
libtorch_python_sources.extend(native.glob(["test/cpp/tensorexpr/test_*.cpp"]))
compiler_flags_cpu = [
"-DUSE_C10D",
@ -433,7 +435,7 @@ def add_torch_libs():
"-Wno-unknown-pragmas",
],
},
"headers": native.glob(["torch/csrc/**/*.h", "torch/csrc/generic/*.cpp", "test/cpp/jit/*.h"]),
"headers": native.glob(["torch/csrc/**/*.h", "torch/csrc/generic/*.cpp", "test/cpp/jit/*.h", "test/cpp/tensorexpr/*.h"]),
}
propagated_pp_flags = [
"-Icaffe2",

View File

@ -118,6 +118,7 @@ if(BUILD_TEST AND NOT USE_ROCM)
add_definitions(-DBUILDING_TESTS)
list(APPEND TORCH_PYTHON_SRCS
${TORCH_ROOT}/test/cpp/jit/torch_python_test.cpp
${TENSOREXPR_TEST_SRCS_WITH_PADDED}
${JIT_TEST_SRCS}
)
endif()

View File

@ -103,6 +103,7 @@ bool loadPythonClasses() {
#if !defined(__HIP_PLATFORM_HCC__)
TORCH_API void runJITCPPTests(bool runCuda);
TORCH_API void runTENSOREXPRCPPTests(bool runCuda);
#endif
void initJITBindings(PyObject* module) {
@ -328,9 +329,23 @@ void initJITBindings(PyObject* module) {
},
py::arg("run_cuda"))
.def("_jit_has_cpp_tests", []() { return true; })
.def(
"_run_tensorexpr_cpp_tests",
[](bool runCuda) {
// We have to release the GIL inside this method, because if we
// happen to initialize the autograd engine in these tests, the
// newly spawned worker threads will try to initialize their
// PyThreadState*, and they need the GIL for this.
pybind11::gil_scoped_release _no_gil;
return runTENSOREXPRCPPTests(runCuda);
},
py::arg("run_cuda"))
.def("_has_tensorexpr_cpp_tests", []() { return true; })
#else
.def("_jit_run_cpp_tests", []() { throw std::exception(); })
.def("_jit_has_cpp_tests", []() { return false; })
.def("_run_tensorexpr_cpp_tests", []() { throw std::exception(); })
.def("_has_tensorexpr_cpp_tests", []() { return false; })
#endif
.def(
"_jit_flatten",

View File

@ -4,6 +4,7 @@
#include <unordered_map>
#include <vector>
#include <c10/macros/Macros.h>
#include <c10/util/Logging.h>
#include <torch/csrc/jit/tensorexpr/buffer.h>
#include <torch/csrc/jit/tensorexpr/codegen.h>
@ -98,6 +99,26 @@ inline bool mod_value(bool lhs, bool rhs) {
throw std::runtime_error("Attempted modulus of bool");
}
template <typename T>
inline typename std::enable_if<std::is_integral<T>::value, T>::type div_value(
T lhs,
T rhs) {
TORCH_CHECK(rhs != 0, "Division by zero");
return lhs / rhs;
}
template <typename T>
inline typename std::enable_if<std::is_floating_point<T>::value, T>::
type __ubsan_ignore_float_divide_by_zero__
div_value(T lhs, T rhs) {
return lhs / rhs;
}
inline bool div_value(bool lhs, bool rhs) {
LOG(FATAL) << "Attempted division of bool";
return false;
}
class SimpleIREvaluator : public CodeGen, public IRVisitor {
public:
using CodeGen::CodeGen;
@ -205,7 +226,7 @@ class SimpleIREvaluator : public CodeGen, public IRVisitor {
result_v[i] = lhs_v[i] * rhs_v[i];
break;
case IRNodeType::kDiv:
result_v[i] = lhs_v[i] / rhs_v[i];
result_v[i] = div_value(lhs_v[i], rhs_v[i]);
break;
case IRNodeType::kMod:
result_v[i] = mod_value(lhs_v[i], rhs_v[i]);