Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)

commit 9e22d15f14
parent 930d218fbf
committed by Facebook GitHub Bot

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/35454
Differential Revision: D20665160
Pulled By: Krovatkin
fbshipit-source-id: e04cbe92b2ee5a3288f3c4e5c83533bfea85bf85
@@ -1,5 +1,6 @@
#include <c10/util/Exception.h>
#include <test/cpp/jit/tests.h>
#include <test/cpp/tensorexpr/tests.h>

namespace torch {
namespace jit {
@@ -25,5 +26,17 @@ JIT_TEST_API void runJITCPPTests(bool runCuda) {
  testTorchSaveError();
}
#undef JIT_TEST

#define JIT_TEST(name) test##name();
JIT_TEST_API void runTENSOREXPRCPPTests(bool runCuda) {
  TH_FORALL_TENSOREXPR_TESTS(JIT_TEST)
  if (runCuda) {
#ifdef USE_CUDA
    TH_FORALL_TENSOREXPR_TESTS_CUDA(JIT_TEST)
#endif
  }
}
#undef JIT_TEST

} // namespace jit
} // namespace torch

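For context, TH_FORALL_TENSOREXPR_TESTS is an X-macro: it applies whatever macro it is handed to every registered test name, so runTENSOREXPRCPPTests above expands into a flat sequence of test##name() calls. A minimal sketch of the pattern (the two test names and the runner function are illustrative; the real list lives in test/cpp/tensorexpr/tests.h):

// Illustrative stand-in for the real registry in test/cpp/tensorexpr/tests.h.
void testExprBasicValueTest();
void testATen_cast_Float();

#define TH_FORALL_TENSOREXPR_TESTS(_) \
  _(ExprBasicValueTest)               \
  _(ATen_cast_Float)

#define JIT_TEST(name) test##name();

void runAllListedTests() {
  // Expands to: testExprBasicValueTest(); testATen_cast_Float();
  TH_FORALL_TENSOREXPR_TESTS(JIT_TEST)
}
#undef JIT_TEST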
@@ -3,6 +3,16 @@ set(TENSOREXPR_TEST_ROOT ${TORCH_ROOT}/test/cpp/tensorexpr)
file(GLOB TENSOREXPR_TEST_SRCS ${TENSOREXPR_TEST_ROOT}/test_*.cpp)
set(TENSOREXPR_TEST_SRCS ${TENSOREXPR_TEST_SRCS} PARENT_SCOPE)

# this is used for running cpp tests from python as part of test_jit.TestJit.test_tensorexpr_cpp
set(TENSOREXPR_TEST_SRCS_WITH_PADDED ${TENSOREXPR_TEST_SRCS} ${TENSOREXPR_TEST_ROOT}/padded_buffer.cpp)
if(NOT USE_CUDA)
  list(REMOVE_ITEM TENSOREXPR_TEST_SRCS_WITH_PADDED ${TENSOREXPR_TEST_ROOT}/test_cuda.cpp)
endif()
if(NOT USE_LLVM)
  list(REMOVE_ITEM TENSOREXPR_TEST_SRCS_WITH_PADDED ${TENSOREXPR_TEST_ROOT}/test_llvm.cpp)
endif()
set(TENSOREXPR_TEST_SRCS_WITH_PADDED ${TENSOREXPR_TEST_SRCS_WITH_PADDED} PARENT_SCOPE)

add_executable(test_tensorexpr
  ${TORCH_ROOT}/test/cpp/common/main.cpp
  ${TENSOREXPR_TEST_ROOT}/gtest.cpp

@@ -9,7 +9,7 @@ namespace jit {
  TEST(TensorExprTest, name) { \
    test##name();              \
  }
TH_FORALL_TESTS(TENSOREXPR_GTEST)
TH_FORALL_TENSOREXPR_TESTS(TENSOREXPR_GTEST)
#undef TENSOREXPR_GTEST

#ifdef TORCH_ENABLE_LLVM
@@ -17,7 +17,7 @@ TH_FORALL_TESTS(TENSOREXPR_GTEST)
  TEST(TensorExprTest, name##_LLVM) { \
    test##name();                     \
  }
TH_FORALL_TESTS_LLVM(TENSOREXPR_GTEST_LLVM)
TH_FORALL_TENSOREXPR_TESTS_LLVM(TENSOREXPR_GTEST_LLVM)
#undef TENSOREXPR_GTEST_LLVM
#endif

@@ -26,7 +26,7 @@ TH_FORALL_TESTS_LLVM(TENSOREXPR_GTEST_LLVM)
  TEST(TensorExprTest, name##_CUDA) { \
    test##name();                     \
  }
TH_FORALL_TESTS_CUDA(TENSOREXPR_GTEST_CUDA)
TH_FORALL_TENSOREXPR_TESTS_CUDA(TENSOREXPR_GTEST_CUDA)
#undef TENSOREXPR_GTEST_CUDA
#endif

test/cpp/tensorexpr/gtest_assert_float_eq.h (new file, 118 lines)
@@ -0,0 +1,118 @@
#pragma once

#include <cmath>
// Copyright 2005, Google Inc.
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
//     * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
//     * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
//     * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//
// The Google C++ Testing and Mocking Framework (Google Test)
//
// This header file declares functions and macros used internally by
// Google Test. They are subject to change without notice.

using Bits = uint32_t;

// this avoids the "dereferencing type-punned pointer
// will break strict-aliasing rules" error
union Float {
  float float_;
  Bits bits_;
};

// # of bits in a number.
static const size_t kBitCount = 8*sizeof(Bits);
// The mask for the sign bit.
static const Bits kSignBitMask = static_cast<Bits>(1) << (kBitCount - 1);

// GOOGLETEST_CM0001 DO NOT DELETE

// Converts an integer from the sign-and-magnitude representation to
// the biased representation. More precisely, let N be 2 to the
// power of (kBitCount - 1), an integer x is represented by the
// unsigned number x + N.
//
// For instance,
//
//   -N + 1 (the most negative number representable using
//          sign-and-magnitude) is represented by 1;
//   0      is represented by N; and
//   N - 1  (the biggest number representable using
//          sign-and-magnitude) is represented by 2N - 1.
//
// Read http://en.wikipedia.org/wiki/Signed_number_representations
// for more details on signed number representations.
static Bits SignAndMagnitudeToBiased(const Bits &sam) {
  if (kSignBitMask & sam) {
    // sam represents a negative number.
    return ~sam + 1;
  } else {
    // sam represents a positive number.
    return kSignBitMask | sam;
  }
}

// Given two numbers in the sign-and-magnitude representation,
// returns the distance between them as an unsigned number.
static Bits DistanceBetweenSignAndMagnitudeNumbers(const Bits &sam1,
                                                   const Bits &sam2) {
  const Bits biased1 = SignAndMagnitudeToBiased(sam1);
  const Bits biased2 = SignAndMagnitudeToBiased(sam2);
  return (biased1 >= biased2) ? (biased1 - biased2) : (biased2 - biased1);
}

// How many ULP's (Units in the Last Place) we want to tolerate when
// comparing two numbers. The larger the value, the more error we
// allow. A 0 value means that two numbers must be exactly the same
// to be considered equal.
//
// The maximum error of a single floating-point operation is 0.5
// units in the last place. On Intel CPU's, all floating-point
// calculations are done with 80-bit precision, while double has 64
// bits. Therefore, 4 should be enough for ordinary use.
//
// See the following article for more details on ULP:
// http://randomascii.wordpress.com/2012/02/25/comparing-floating-point-numbers-2012-edition/
static const size_t kMaxUlps = 4;

// Returns true if and only if this number is at most kMaxUlps ULP's away
// from rhs. In particular, this function:
//
//   - returns false if either number is (or both are) NAN.
//   - treats really large numbers as almost equal to infinity.
//   - thinks +0.0 and -0.0 are 0 DLP's apart.
inline bool AlmostEquals(float lhs, float rhs) {
  // The IEEE standard says that any comparison operation involving
  // a NAN must return false.
  if (std::isnan(lhs) || std::isnan(rhs))
    return false;

  Float l = {lhs};
  Float r = {rhs};

  return DistanceBetweenSignAndMagnitudeNumbers(l.bits_, r.bits_) <= kMaxUlps;
}

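As a quick illustration of the ULP-based comparison defined above: values one representable step apart compare as equal (the tolerance kMaxUlps is 4), while values that differ by many ULPs do not. A small, hypothetical usage sketch:

#include <cmath>
#include <cstdio>
#include "test/cpp/tensorexpr/gtest_assert_float_eq.h" // AlmostEquals, added above

int main() {
  float a = 1.0f;
  float b = std::nextafter(a, 2.0f); // exactly 1 ULP above 1.0f
  std::printf("%d\n", AlmostEquals(a, b));         // 1: within 4 ULPs
  std::printf("%d\n", AlmostEquals(1.0f, 1.001f)); // 0: thousands of ULPs apart
  return 0;
}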
@@ -1,9 +1,6 @@
#include "test/cpp/tensorexpr/padded_buffer.h"

#include <sstream>

#include <gtest/gtest.h>

#include <c10/util/Logging.h>

namespace torch {

@@ -169,12 +169,20 @@ class PaddedBuffer : public PaddedBufferBase {
  // Verify the watermarks in the paddings are intact.
  void ValidateWatermark() const {
    for (int i = 0; i < kPaddingSize; i++) {
      EXPECT_EQ(data_[i], kPaddingValue)
          << "left-side watermark broken: "
          << "index: " << i << ", name: " << name();
      EXPECT_EQ(data_[i + total_size_ + kPaddingSize], kPaddingValue)
          << "right-side watermark broken: "
          << "index: " << i << ", name: " << name();
      ASSERT_EQ(
          data_[i],
          kPaddingValue,
          "left-side watermark broken: index: ",
          i,
          ", name: ",
          name());
      ASSERT_EQ(
          data_[i + total_size_ + kPaddingSize],
          kPaddingValue,
          "right-side watermark broken: index: ",
          i,
          ", name: ",
          name());
    }
  }

@@ -183,9 +191,13 @@ class PaddedBuffer : public PaddedBufferBase {
    DCHECK(backup_data_.size() == data_.size())
        << "Please make sure you have call Backup() before calling CheckBackup()";
    for (int i = 0; i < total_size_; i++) {
      EXPECT_EQ(data_[i + kPaddingSize], backup_data_[i + kPaddingSize])
          << "mismatch against backup, "
          << "index: " << i << ", name: " << name();
      ASSERT_EQ(
          data_[i + kPaddingSize],
          backup_data_[i + kPaddingSize],
          "mismatch against backup, index: ",
          i,
          ", name: ",
          name());
    }
  }

@@ -219,8 +231,8 @@ void ExpectAllEqual(const PaddedBuffer<T>& f1, const PaddedBuffer<T>& f2) {
  f1.ValidateWatermark();
  f2.ValidateWatermark();
  for (int i = 0; i < total_size; i++) {
    EXPECT_EQ(v1[kPaddingSize + i], v2[kPaddingSize + i])
        << CompareErrorMsg(f1, f2, i);
    ASSERT_EQ(
        v1[kPaddingSize + i], v2[kPaddingSize + i], CompareErrorMsg(f1, f2, i));
  }
}

@@ -237,8 +249,11 @@ void ExpectAllNear(
  f1.ValidateWatermark();
  f2.ValidateWatermark();
  for (int i = 0; i < total_size; i++) {
    ASSERT_NEAR(v1[kPaddingSize + i], v2[kPaddingSize + i], abs_error);
    // << CompareErrorMsg(f1, f2, i);
    ASSERT_NEAR(
        v1[kPaddingSize + i],
        v2[kPaddingSize + i],
        abs_error,
        CompareErrorMsg(f1, f2, i));
  }
}

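The ValidateWatermark checks above guard against out-of-bounds writes: PaddedBuffer keeps kPaddingSize sentinel values on each side of the payload and asserts they are untouched after a kernel runs. A simplified sketch of that idea (hypothetical names, not the actual PaddedBuffer implementation):

#include <cassert>
#include <vector>

// Simplified stand-in for PaddedBuffer's watermarking.
struct GuardedBuffer {
  static constexpr int kPad = 64;          // padding elements on each side
  static constexpr int kSentinel = 0x5A5A; // watermark value
  std::vector<int> data;
  int size;

  explicit GuardedBuffer(int n) : data(n + 2 * kPad, kSentinel), size(n) {}
  int* payload() { return data.data() + kPad; } // writable region of length size

  void validate() const {
    for (int i = 0; i < kPad; i++) {
      assert(data[i] == kSentinel);               // left watermark intact
      assert(data[kPad + size + i] == kSentinel); // right watermark intact
    }
  }
};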
@ -1,9 +1,10 @@
|
||||
#include <algorithm>
|
||||
#include <sstream>
|
||||
#include <stdexcept>
|
||||
#include "test/cpp/tensorexpr/test_base.h"
|
||||
|
||||
#include <c10/macros/Macros.h>
|
||||
#include "test/cpp/tensorexpr/padded_buffer.h"
|
||||
#include "test/cpp/tensorexpr/test_base.h"
|
||||
#include "torch/csrc/jit/tensorexpr/ir_printer.h"
|
||||
|
||||
namespace torch {
|
||||
@ -34,8 +35,8 @@ void testATen_cast_Float() {
|
||||
ir_eval(a_v, b_v);
|
||||
|
||||
for (int i = 0; i < kTotalSize; ++i) {
|
||||
EXPECT_EQ(a_v(i), i) << "index: " << i;
|
||||
EXPECT_EQ(b_v(i), static_cast<float>(i)) << "index: " << i;
|
||||
ASSERT_EQ(a_v(i), i, "index: ", i);
|
||||
ASSERT_EQ(b_v(i), static_cast<float>(i), "index: ", i);
|
||||
}
|
||||
}
|
||||
|
||||
@ -62,8 +63,8 @@ void testATennegInt() {
|
||||
ir_eval(a_v, b_v);
|
||||
|
||||
for (int i = 0; i < kTotalSize; ++i) {
|
||||
EXPECT_EQ(a_v(i), i) << "index: " << i;
|
||||
EXPECT_EQ(b_v(i), -static_cast<float>(i)) << "index: " << i;
|
||||
ASSERT_EQ(a_v(i), i, "index: ", i);
|
||||
ASSERT_EQ(b_v(i), -static_cast<float>(i), "index: ", i);
|
||||
}
|
||||
}
|
||||
|
||||
@ -90,8 +91,8 @@ void testATennegFloat() {
|
||||
ir_eval(a_v, b_v);
|
||||
|
||||
for (int i = 0; i < kTotalSize; ++i) {
|
||||
EXPECT_EQ(a_v(i), i) << "index: " << i;
|
||||
EXPECT_EQ(b_v(i), -i) << "index: " << i;
|
||||
ASSERT_EQ(a_v(i), i, "index: ", i);
|
||||
ASSERT_EQ(b_v(i), -i, "index: ", i);
|
||||
}
|
||||
}
|
||||
|
||||
@ -125,10 +126,10 @@ void testATenaddInt() {
|
||||
ir_eval(a_v, b_v, c_v, d_v);
|
||||
|
||||
for (int i = 0; i < kTotalSize; ++i) {
|
||||
EXPECT_EQ(a_v(i), i) << "index: " << i;
|
||||
EXPECT_EQ(b_v(i), 2 * i + 1) << "index: " << i;
|
||||
EXPECT_EQ(c_v(i), 3 * i + 2) << "index: " << i;
|
||||
EXPECT_EQ(d_v(i), a_v(i) + b_v(i) * c_v(i)) << "index: " << i;
|
||||
ASSERT_EQ(a_v(i), i, "index: ", i);
|
||||
ASSERT_EQ(b_v(i), 2 * i + 1, "index: ", i);
|
||||
ASSERT_EQ(c_v(i), 3 * i + 2, "index: ", i);
|
||||
ASSERT_EQ(d_v(i), a_v(i) + b_v(i) * c_v(i), "index: ", i);
|
||||
}
|
||||
}
|
||||
|
||||
@ -162,10 +163,10 @@ void testATenaddFloat() {
|
||||
ir_eval(a_v, b_v, c_v, d_v);
|
||||
|
||||
for (int i = 0; i < kTotalSize; ++i) {
|
||||
EXPECT_EQ(a_v(i), i) << "index: " << i;
|
||||
EXPECT_EQ(b_v(i), 2 * i + 1) << "index: " << i;
|
||||
EXPECT_EQ(c_v(i), 3 * i + 2) << "index: " << i;
|
||||
EXPECT_EQ(d_v(i), a_v(i) + b_v(i) * c_v(i)) << "index: " << i;
|
||||
ASSERT_EQ(a_v(i), i, "index: ", i);
|
||||
ASSERT_EQ(b_v(i), 2 * i + 1, "index: ", i);
|
||||
ASSERT_EQ(c_v(i), 3 * i + 2, "index: ", i);
|
||||
ASSERT_EQ(d_v(i), a_v(i) + b_v(i) * c_v(i), "index: ", i);
|
||||
}
|
||||
}
|
||||
|
||||
@ -199,10 +200,10 @@ void testATensubInt() {
|
||||
ir_eval(a_v, b_v, c_v, d_v);
|
||||
|
||||
for (int i = 0; i < kTotalSize; ++i) {
|
||||
EXPECT_EQ(a_v(i), i) << "index: " << i;
|
||||
EXPECT_EQ(b_v(i), 2 * i + 1) << "index: " << i;
|
||||
EXPECT_EQ(c_v(i), 3 * i + 2) << "index: " << i;
|
||||
EXPECT_EQ(d_v(i), a_v(i) - b_v(i) * c_v(i)) << "index: " << i;
|
||||
ASSERT_EQ(a_v(i), i, "index: ", i);
|
||||
ASSERT_EQ(b_v(i), 2 * i + 1, "index: ", i);
|
||||
ASSERT_EQ(c_v(i), 3 * i + 2, "index: ", i);
|
||||
ASSERT_EQ(d_v(i), a_v(i) - b_v(i) * c_v(i), "index: ", i);
|
||||
}
|
||||
}
|
||||
|
||||
@ -236,10 +237,10 @@ void testATensubFloat() {
|
||||
ir_eval(a_v, b_v, c_v, d_v);
|
||||
|
||||
for (int i = 0; i < kTotalSize; ++i) {
|
||||
EXPECT_EQ(a_v(i), i) << "index: " << i;
|
||||
EXPECT_EQ(b_v(i), 2 * i + 1) << "index: " << i;
|
||||
EXPECT_EQ(c_v(i), 3 * i + 2) << "index: " << i;
|
||||
EXPECT_EQ(d_v(i), a_v(i) - b_v(i) * c_v(i)) << "index: " << i;
|
||||
ASSERT_EQ(a_v(i), i, "index: ", i);
|
||||
ASSERT_EQ(b_v(i), 2 * i + 1, "index: ", i);
|
||||
ASSERT_EQ(c_v(i), 3 * i + 2, "index: ", i);
|
||||
ASSERT_EQ(d_v(i), a_v(i) - b_v(i) * c_v(i), "index: ", i);
|
||||
}
|
||||
}
|
||||
|
||||
@ -274,10 +275,10 @@ void testATenlerp() {
|
||||
ir_eval(a_v, b_v, c_v, d_v);
|
||||
|
||||
for (int i = 0; i < kTotalSize; ++i) {
|
||||
EXPECT_EQ(a_v(i), i) << "index: " << i;
|
||||
EXPECT_EQ(b_v(i), 2 * i + 1) << "index: " << i;
|
||||
EXPECT_EQ(c_v(i), 3 * i + 2) << "index: " << i;
|
||||
EXPECT_EQ(d_v(i), a_v(i) + c_v(i) * (b_v(i) - a_v(i))) << "index: " << i;
|
||||
ASSERT_EQ(a_v(i), i, "index: ", i);
|
||||
ASSERT_EQ(b_v(i), 2 * i + 1, "index: ", i);
|
||||
ASSERT_EQ(c_v(i), 3 * i + 2, "index: ", i);
|
||||
ASSERT_EQ(d_v(i), a_v(i) + c_v(i) * (b_v(i) - a_v(i)), "index: ", i);
|
||||
}
|
||||
}
|
||||
|
||||
@ -316,11 +317,11 @@ void testATenaddcmulInt() {
|
||||
ir_eval(a_v, b_v, c_v, d_v, e_v);
|
||||
|
||||
for (int i = 0; i < kTotalSize; ++i) {
|
||||
EXPECT_EQ(a_v(i), i) << "index: " << i;
|
||||
EXPECT_EQ(b_v(i), 2 * i + 1) << "index: " << i;
|
||||
EXPECT_EQ(c_v(i), 3 * i + 2) << "index: " << i;
|
||||
EXPECT_EQ(d_v(i), 5 * i + 3) << "index: " << i;
|
||||
EXPECT_EQ(e_v(i), a_v(i) + b_v(i) * c_v(i) * d_v(i)) << "index: " << i;
|
||||
ASSERT_EQ(a_v(i), i, "index: ", i);
|
||||
ASSERT_EQ(b_v(i), 2 * i + 1, "index: ", i);
|
||||
ASSERT_EQ(c_v(i), 3 * i + 2, "index: ", i);
|
||||
ASSERT_EQ(d_v(i), 5 * i + 3, "index: ", i);
|
||||
ASSERT_EQ(e_v(i), a_v(i) + b_v(i) * c_v(i) * d_v(i), "index: ", i);
|
||||
}
|
||||
}
|
||||
|
||||
@ -359,11 +360,11 @@ void testATenaddcmulFloat() {
|
||||
ir_eval(a_v, b_v, c_v, d_v, e_v);
|
||||
|
||||
for (int i = 0; i < kTotalSize; ++i) {
|
||||
EXPECT_EQ(a_v(i), i) << "index: " << i;
|
||||
EXPECT_EQ(b_v(i), 2 * i + 1) << "index: " << i;
|
||||
EXPECT_EQ(c_v(i), 3 * i + 2) << "index: " << i;
|
||||
EXPECT_EQ(d_v(i), 5 * i + 3) << "index: " << i;
|
||||
EXPECT_FLOAT_EQ(e_v(i), a_v(i) + b_v(i) * c_v(i) * d_v(i)) << "index: " << i;
|
||||
ASSERT_EQ(a_v(i), i, "index: ", i);
|
||||
ASSERT_EQ(b_v(i), 2 * i + 1, "index: ", i);
|
||||
ASSERT_EQ(c_v(i), 3 * i + 2, "index: ", i);
|
||||
ASSERT_EQ(d_v(i), 5 * i + 3, "index: ", i);
|
||||
ASSERT_FLOAT_EQ(e_v(i), a_v(i) + b_v(i) * c_v(i) * d_v(i), "index: ", i);
|
||||
}
|
||||
}
|
||||
|
||||
@ -393,9 +394,9 @@ void testATenmulInt() {
|
||||
ir_eval(a_v, b_v, c_v);
|
||||
|
||||
for (int i = 0; i < kTotalSize; ++i) {
|
||||
EXPECT_EQ(a_v(i), i) << "index: " << i;
|
||||
EXPECT_EQ(b_v(i), 2 * i + 1) << "index: " << i;
|
||||
EXPECT_EQ(c_v(i), a_v(i) * b_v(i)) << "index: " << i;
|
||||
ASSERT_EQ(a_v(i), i, "index: ", i);
|
||||
ASSERT_EQ(b_v(i), 2 * i + 1, "index: ", i);
|
||||
ASSERT_EQ(c_v(i), a_v(i) * b_v(i), "index: ", i);
|
||||
}
|
||||
}
|
||||
|
||||
@ -425,9 +426,9 @@ void testATenmulFloat() {
|
||||
ir_eval(a_v, b_v, c_v);
|
||||
|
||||
for (int i = 0; i < kTotalSize; ++i) {
|
||||
EXPECT_EQ(a_v(i), i) << "index: " << i;
|
||||
EXPECT_EQ(b_v(i), 2 * i + 1) << "index: " << i;
|
||||
EXPECT_EQ(c_v(i), a_v(i) * b_v(i)) << "index: " << i;
|
||||
ASSERT_EQ(a_v(i), i, "index: ", i);
|
||||
ASSERT_EQ(b_v(i), 2 * i + 1, "index: ", i);
|
||||
ASSERT_EQ(c_v(i), a_v(i) * b_v(i), "index: ", i);
|
||||
}
|
||||
}
|
||||
|
||||
@ -457,9 +458,9 @@ void testATendivInt() {
|
||||
ir_eval(a_v, b_v, c_v);
|
||||
|
||||
for (int i = 0; i < kTotalSize; ++i) {
|
||||
EXPECT_EQ(a_v(i), 2 * i + 1) << "index: " << i;
|
||||
EXPECT_EQ(b_v(i), i + 1) << "index: " << i;
|
||||
EXPECT_EQ(c_v(i), a_v(i) / b_v(i)) << "index: " << i;
|
||||
ASSERT_EQ(a_v(i), 2 * i + 1, "index: ", i);
|
||||
ASSERT_EQ(b_v(i), i + 1, "index: ", i);
|
||||
ASSERT_EQ(c_v(i), a_v(i) / b_v(i), "index: ", i);
|
||||
}
|
||||
}
|
||||
|
||||
@ -489,9 +490,9 @@ void testATendivFloat() {
|
||||
ir_eval(a_v, b_v, c_v);
|
||||
|
||||
for (int i = 0; i < kTotalSize; ++i) {
|
||||
EXPECT_EQ(a_v(i), 2 * i + 1) << "index: " << i;
|
||||
EXPECT_EQ(b_v(i), i + 1) << "index: " << i;
|
||||
EXPECT_EQ(c_v(i), a_v(i) / b_v(i)) << "index: " << i;
|
||||
ASSERT_EQ(a_v(i), 2 * i + 1, "index: ", i);
|
||||
ASSERT_EQ(b_v(i), i + 1, "index: ", i);
|
||||
ASSERT_EQ(c_v(i), a_v(i) / b_v(i), "index: ", i);
|
||||
}
|
||||
}
|
||||
|
||||
@ -521,9 +522,9 @@ void testATenmaxInt() {
|
||||
ir_eval(a_v, b_v, c_v);
|
||||
|
||||
for (int i = 0; i < kTotalSize; ++i) {
|
||||
EXPECT_EQ(a_v(i), i) << "index: " << i;
|
||||
EXPECT_EQ(b_v(i), 2 * i + 1) << "index: " << i;
|
||||
EXPECT_EQ(c_v(i), std::max(a_v(i), b_v(i))) << "index: " << i;
|
||||
ASSERT_EQ(a_v(i), i, "index: ", i);
|
||||
ASSERT_EQ(b_v(i), 2 * i + 1, "index: ", i);
|
||||
ASSERT_EQ(c_v(i), std::max(a_v(i), b_v(i)), "index: ", i);
|
||||
}
|
||||
}
|
||||
|
||||
@ -553,9 +554,9 @@ void testATenmaxFloat() {
|
||||
ir_eval(a_v, b_v, c_v);
|
||||
|
||||
for (int i = 0; i < kTotalSize; ++i) {
|
||||
EXPECT_EQ(a_v(i), i) << "index: " << i;
|
||||
EXPECT_EQ(b_v(i), 2 * i + 1) << "index: " << i;
|
||||
EXPECT_EQ(c_v(i), std::fmax(a_v(i), b_v(i))) << "index: " << i;
|
||||
ASSERT_EQ(a_v(i), i, "index: ", i);
|
||||
ASSERT_EQ(b_v(i), 2 * i + 1, "index: ", i);
|
||||
ASSERT_EQ(c_v(i), std::fmax(a_v(i), b_v(i)), "index: ", i);
|
||||
}
|
||||
}
|
||||
|
||||
@ -585,9 +586,9 @@ void testATenminInt() {
|
||||
ir_eval(a_v, b_v, c_v);
|
||||
|
||||
for (int i = 0; i < kTotalSize; ++i) {
|
||||
EXPECT_EQ(a_v(i), i) << "index: " << i;
|
||||
EXPECT_EQ(b_v(i), 2 * i + 1) << "index: " << i;
|
||||
EXPECT_EQ(c_v(i), std::min(a_v(i), b_v(i))) << "index: " << i;
|
||||
ASSERT_EQ(a_v(i), i, "index: ", i);
|
||||
ASSERT_EQ(b_v(i), 2 * i + 1, "index: ", i);
|
||||
ASSERT_EQ(c_v(i), std::min(a_v(i), b_v(i)), "index: ", i);
|
||||
}
|
||||
}
|
||||
|
||||
@ -617,9 +618,9 @@ void testATenminFloat() {
|
||||
ir_eval(a_v, b_v, c_v);
|
||||
|
||||
for (int i = 0; i < kTotalSize; ++i) {
|
||||
EXPECT_EQ(a_v(i), i) << "index: " << i;
|
||||
EXPECT_EQ(b_v(i), 2 * i + 1) << "index: " << i;
|
||||
EXPECT_EQ(c_v(i), std::fmin(a_v(i), b_v(i))) << "index: " << i;
|
||||
ASSERT_EQ(a_v(i), i, "index: ", i);
|
||||
ASSERT_EQ(b_v(i), 2 * i + 1, "index: ", i);
|
||||
ASSERT_EQ(c_v(i), std::fmin(a_v(i), b_v(i)), "index: ", i);
|
||||
}
|
||||
}
|
||||
|
||||
@ -650,9 +651,9 @@ void testATen_sigmoid_backward() {
|
||||
ir_eval(a_v, b_v, c_v);
|
||||
|
||||
for (int i = 0; i < kTotalSize; ++i) {
|
||||
EXPECT_EQ(a_v(i), i) << "index: " << i;
|
||||
EXPECT_EQ(b_v(i), 2 * i + 1) << "index: " << i;
|
||||
EXPECT_EQ(c_v(i), a_v(i) * b_v(i) * (1.0f - b_v(i))) << "index: " << i;
|
||||
ASSERT_EQ(a_v(i), i, "index: ", i);
|
||||
ASSERT_EQ(b_v(i), 2 * i + 1, "index: ", i);
|
||||
ASSERT_EQ(c_v(i), a_v(i) * b_v(i) * (1.0f - b_v(i)), "index: ", i);
|
||||
}
|
||||
}
|
||||
|
||||
@ -683,13 +684,13 @@ void testATen_tanh_backward() {
|
||||
ir_eval(a_v, b_v, c_v);
|
||||
|
||||
for (int i = 0; i < kTotalSize; ++i) {
|
||||
EXPECT_EQ(a_v(i), i) << "index: " << i;
|
||||
EXPECT_EQ(b_v(i), 2 * i + 1) << "index: " << i;
|
||||
EXPECT_EQ(c_v(i), a_v(i) * (1.0f - (b_v(i) * b_v(i)))) << "index: " << i;
|
||||
ASSERT_EQ(a_v(i), i, "index: ", i);
|
||||
ASSERT_EQ(b_v(i), 2 * i + 1, "index: ", i);
|
||||
ASSERT_EQ(c_v(i), a_v(i) * (1.0f - (b_v(i) * b_v(i))), "index: ", i);
|
||||
}
|
||||
}
|
||||
|
||||
void testATenreciprocal() {
|
||||
void __ubsan_ignore_float_divide_by_zero__ testATenreciprocal() {
|
||||
KernelScope kernel_scope;
|
||||
const int kTotalSize = 128;
|
||||
Buffer a_buf(VarHandle("A", kHandle), kFloat, {ExprHandle(kTotalSize)});
|
||||
@ -711,8 +712,8 @@ void testATenreciprocal() {
|
||||
ir_eval(a_v, b_v);
|
||||
|
||||
for (int i = 0; i < kTotalSize; ++i) {
|
||||
EXPECT_EQ(a_v(i), i) << "index: " << i;
|
||||
EXPECT_EQ(b_v(i), 1.0f / i) << "index: " << i;
|
||||
ASSERT_EQ(a_v(i), i, "index: ", i);
|
||||
ASSERT_EQ(b_v(i), 1.0f / i, "index: ", i);
|
||||
}
|
||||
}
|
||||
|
||||
@ -738,8 +739,8 @@ void testATenreluInt() {
|
||||
ir_eval(a_v, b_v);
|
||||
|
||||
for (int i = 0; i < kTotalSize; ++i) {
|
||||
EXPECT_EQ(a_v(i), i - 64) << "index: " << i;
|
||||
EXPECT_EQ(b_v(i), std::max(a_v(i), 0)) << "index: " << i;
|
||||
ASSERT_EQ(a_v(i), i - 64, "index: ", i);
|
||||
ASSERT_EQ(b_v(i), std::max(a_v(i), 0), "index: ", i);
|
||||
}
|
||||
}
|
||||
|
||||
@ -769,8 +770,8 @@ void testATenreluFloat() {
|
||||
ir_eval(a_v, b_v);
|
||||
|
||||
for (int i = 0; i < kTotalSize; ++i) {
|
||||
EXPECT_EQ(a_v(i), i - 64) << "index: " << i;
|
||||
EXPECT_EQ(b_v(i), std::fmax(a_v(i), 0)) << "index: " << i;
|
||||
ASSERT_EQ(a_v(i), i - 64, "index: ", i);
|
||||
ASSERT_EQ(b_v(i), std::fmax(a_v(i), 0), "index: ", i);
|
||||
}
|
||||
}
|
||||
|
||||
@ -796,8 +797,8 @@ void testATenlogFloat() {
|
||||
ir_eval(a_v, b_v);
|
||||
|
||||
for (int i = 0; i < kTotalSize; ++i) {
|
||||
EXPECT_EQ(a_v(i), i + 10) << "index: " << i;
|
||||
EXPECT_EQ(b_v(i), std::log(a_v(i))) << "index: " << i;
|
||||
ASSERT_EQ(a_v(i), i + 10, "index: ", i);
|
||||
ASSERT_EQ(b_v(i), std::log(a_v(i)), "index: ", i);
|
||||
}
|
||||
}
|
||||
|
||||
@ -823,8 +824,8 @@ void testATenlog10Float() {
|
||||
ir_eval(a_v, b_v);
|
||||
|
||||
for (int i = 0; i < kTotalSize; ++i) {
|
||||
EXPECT_EQ(a_v(i), i + 10) << "index: " << i;
|
||||
EXPECT_EQ(b_v(i), std::log10(a_v(i))) << "index: " << i;
|
||||
ASSERT_EQ(a_v(i), i + 10, "index: ", i);
|
||||
ASSERT_EQ(b_v(i), std::log10(a_v(i)), "index: ", i);
|
||||
}
|
||||
}
|
||||
|
||||
@ -850,8 +851,8 @@ void testATenlog2Float() {
|
||||
ir_eval(a_v, b_v);
|
||||
|
||||
for (int i = 0; i < kTotalSize; ++i) {
|
||||
EXPECT_EQ(a_v(i), i + 10) << "index: " << i;
|
||||
EXPECT_EQ(b_v(i), std::log2(a_v(i))) << "index: " << i;
|
||||
ASSERT_EQ(a_v(i), i + 10, "index: ", i);
|
||||
ASSERT_EQ(b_v(i), std::log2(a_v(i)), "index: ", i);
|
||||
}
|
||||
}
|
||||
|
||||
@ -877,8 +878,8 @@ void testATenexpFloat() {
|
||||
ir_eval(a_v, b_v);
|
||||
|
||||
for (int i = 0; i < kTotalSize; ++i) {
|
||||
EXPECT_EQ(a_v(i), i / 10.0f) << "index: " << i;
|
||||
EXPECT_EQ(b_v(i), std::exp(a_v(i))) << "index: " << i;
|
||||
ASSERT_EQ(a_v(i), i / 10.0f, "index: ", i);
|
||||
ASSERT_EQ(b_v(i), std::exp(a_v(i)), "index: ", i);
|
||||
}
|
||||
}
|
||||
|
||||
@ -904,8 +905,8 @@ void testATenerfFloat() {
|
||||
ir_eval(a_v, b_v);
|
||||
|
||||
for (int i = 0; i < kTotalSize; ++i) {
|
||||
EXPECT_EQ(a_v(i), i / 10.0f) << "index: " << i;
|
||||
EXPECT_EQ(b_v(i), std::erf(a_v(i))) << "index: " << i;
|
||||
ASSERT_EQ(a_v(i), i / 10.0f, "index: ", i);
|
||||
ASSERT_EQ(b_v(i), std::erf(a_v(i)), "index: ", i);
|
||||
}
|
||||
}
|
||||
|
||||
@ -931,8 +932,8 @@ void testATencosFloat() {
|
||||
ir_eval(a_v, b_v);
|
||||
|
||||
for (int i = 0; i < kTotalSize; ++i) {
|
||||
EXPECT_EQ(a_v(i), i / 10.0f) << "index: " << i;
|
||||
EXPECT_EQ(b_v(i), std::cos(a_v(i))) << "index: " << i;
|
||||
ASSERT_EQ(a_v(i), i / 10.0f, "index: ", i);
|
||||
ASSERT_EQ(b_v(i), std::cos(a_v(i)), "index: ", i);
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -1,7 +1,45 @@
#pragma once

#if defined(USE_GTEST)
#include <gtest/gtest.h>
#include <test/cpp/common/support.h>
#else
#include "c10/util/Exception.h"
#include "test/cpp/tensorexpr/gtest_assert_float_eq.h"
#include <cmath>
#define ASSERT_EQ(x, y, ...) TORCH_INTERNAL_ASSERT((x) == (y), __VA_ARGS__)
#define ASSERT_FLOAT_EQ(x, y, ...) \
  TORCH_INTERNAL_ASSERT(AlmostEquals((x), (y)), __VA_ARGS__)
#define ASSERT_NE(x, y, ...) TORCH_INTERNAL_ASSERT((x) != (y), __VA_ARGS__)
#define ASSERT_GT(x, y, ...) TORCH_INTERNAL_ASSERT((x) > (y), __VA_ARGS__)
#define ASSERT_GE(x, y, ...) TORCH_INTERNAL_ASSERT((x) >= (y), __VA_ARGS__)
#define ASSERT_LT(x, y, ...) TORCH_INTERNAL_ASSERT((x) < (y), __VA_ARGS__)
#define ASSERT_LE(x, y, ...) TORCH_INTERNAL_ASSERT((x) <= (y), __VA_ARGS__)

#define ASSERT_NEAR(x, y, a, ...) \
  TORCH_INTERNAL_ASSERT(std::fabs((x) - (y)) < (a), __VA_ARGS__)

#define ASSERT_TRUE TORCH_INTERNAL_ASSERT
#define ASSERT_FALSE(x) ASSERT_TRUE(!(x))
#define ASSERT_THROWS_WITH(statement, substring)                         \
  try {                                                                  \
    (void)statement;                                                     \
    ASSERT_TRUE(false);                                                  \
  } catch (const std::exception& e) {                                    \
    ASSERT_NE(std::string(e.what()).find(substring), std::string::npos); \
  }
#define ASSERT_ANY_THROW(statement)     \
  {                                     \
    bool threw = false;                 \
    try {                               \
      (void)statement;                  \
    } catch (const std::exception& e) { \
      threw = true;                     \
    }                                   \
    ASSERT_TRUE(threw);                 \
  }

#endif // defined(USE_GTEST)

namespace torch {
namespace jit {
@@ -15,8 +53,8 @@ void ExpectAllNear(
    const std::string& name = "") {
  ASSERT_EQ(v1.size(), v2.size());
  for (int i = 0; i < v1.size(); i++) {
    EXPECT_NEAR(v1[i], v2[i], threshold)
        << "element index: " << i << ", name: " << name;
    ASSERT_NEAR(
        v1[i], v2[i], threshold, "element index: ", i, ", name: ", name);
  }
}

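Under the non-gtest branch above, everything after the compared values is forwarded to TORCH_INTERNAL_ASSERT, which is why the tests in this commit pass their context (index, name) as extra arguments instead of streaming it with operator<<. A self-contained sketch of the same variadic-assert pattern (checkOrThrow and MY_ASSERT_NEAR are illustrative stand-ins, not the real c10 macros):

#include <cmath>
#include <iostream>
#include <sstream>
#include <stdexcept>

template <typename... Args>
void checkOrThrow(bool cond, const Args&... parts) {
  if (!cond) {
    std::ostringstream msg;
    (msg << ... << parts); // C++17 fold expression builds the message
    throw std::runtime_error(msg.str());
  }
}

#define MY_ASSERT_NEAR(x, y, a, ...) \
  checkOrThrow(std::fabs((x) - (y)) < (a), __VA_ARGS__)

int main() {
  try {
    MY_ASSERT_NEAR(1.0, 1.5, 0.1, "element index: ", 7, ", name: ", "f1");
  } catch (const std::exception& e) {
    std::cout << "failed: " << e.what() << "\n"; // failed: element index: 7, name: f1
  }
}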
@@ -261,7 +261,7 @@ void testCudaTestRand01() {
    sum1 += v;
    sum2 += v * v;
    sum3 += v * v * v;
    EXPECT_TRUE(v >= 0 && v < 1) << "invalid value: " << i << ", " << v;
    ASSERT_TRUE(v >= 0 && v < 1, "invalid value: ", i, ", ", v);
  }
  sum1 /= N;
  sum2 /= N;
@@ -270,9 +270,9 @@ void testCudaTestRand01() {
  float sum2_mean = 1.f / 3;
  float sum3_mean = 1.f / 4;

  EXPECT_NEAR(sum1, sum1_mean, 2e-2);
  EXPECT_NEAR(sum2, sum2_mean, 2e-2);
  EXPECT_NEAR(sum3, sum3_mean, 2e-2);
  ASSERT_NEAR(sum1, sum1_mean, 2e-2);
  ASSERT_NEAR(sum2, sum2_mean, 2e-2);
  ASSERT_NEAR(sum3, sum3_mean, 2e-2);
  cudaFree(c_dev);
}

@ -27,7 +27,7 @@ void testExprBasicValueTest() {
|
||||
ExprHandle a = IntImm::make(2), b = IntImm::make(3);
|
||||
ExprHandle c = Add::make(a, b);
|
||||
SimpleIRExprEval eval(c);
|
||||
EXPECT_EQ(eval.value<int>(), 5);
|
||||
ASSERT_EQ(eval.value<int>(), 5);
|
||||
}
|
||||
|
||||
void testExprBasicValueTest02() {
|
||||
@ -38,7 +38,7 @@ void testExprBasicValueTest02() {
|
||||
ExprHandle d(5.0f);
|
||||
ExprHandle f = (a + b) - (c + d);
|
||||
SimpleIRExprEval eval(f);
|
||||
EXPECT_EQ(eval.value<float>(), -4.0f);
|
||||
ASSERT_EQ(eval.value<float>(), -4.0f);
|
||||
}
|
||||
|
||||
void testExprLetTest01() {
|
||||
@ -48,7 +48,7 @@ void testExprLetTest01() {
|
||||
ExprHandle body = ExprHandle(2.f) + (x * ExprHandle(3.f) + ExprHandle(4.f));
|
||||
ExprHandle result = Let::make(x, ExprHandle(3.f), body);
|
||||
SimpleIRExprEval eval(result);
|
||||
EXPECT_EQ(eval.value<float>(), 2 + (3 * 3 + 4));
|
||||
ASSERT_EQ(eval.value<float>(), 2 + (3 * 3 + 4));
|
||||
}
|
||||
|
||||
void testExprLetTest02() {
|
||||
@ -61,7 +61,7 @@ void testExprLetTest02() {
|
||||
ExprHandle e1 = Let::make(x, ExprHandle(3.f), body);
|
||||
ExprHandle e2 = Let::make(y, ExprHandle(6.f), e1);
|
||||
SimpleIRExprEval eval(e2);
|
||||
EXPECT_EQ(eval.value<float>(), 2 + (3 * 3 + 4 * 6));
|
||||
ASSERT_EQ(eval.value<float>(), 2 + (3 * 3 + 4 * 6));
|
||||
}
|
||||
|
||||
void testExprLetStmtTest01() {
|
||||
@ -97,7 +97,7 @@ void testExprIntTest() {
|
||||
ExprHandle body = ExprHandle(2) + (x * ExprHandle(3) + ExprHandle(4));
|
||||
ExprHandle result = Let::make(x, ExprHandle(3), body);
|
||||
SimpleIRExprEval eval(result);
|
||||
EXPECT_EQ(eval.value<int>(), 2 + (3 * 3 + 4));
|
||||
ASSERT_EQ(eval.value<int>(), 2 + (3 * 3 + 4));
|
||||
}
|
||||
|
||||
void testExprFloatTest() {
|
||||
@ -108,7 +108,7 @@ void testExprFloatTest() {
|
||||
ExprHandle((float)2) + (x * ExprHandle((float)3) + ExprHandle((float)4));
|
||||
ExprHandle result = Let::make(x, ExprHandle((float)3), body);
|
||||
SimpleIRExprEval eval(result);
|
||||
EXPECT_EQ(eval.value<float>(), 2 + (3 * 3 + 4));
|
||||
ASSERT_EQ(eval.value<float>(), 2 + (3 * 3 + 4));
|
||||
}
|
||||
|
||||
void testExprByteTest() {
|
||||
@ -119,7 +119,7 @@ void testExprByteTest() {
|
||||
(x * ExprHandle((uint8_t)3) + ExprHandle((uint8_t)4));
|
||||
ExprHandle result = Let::make(x, ExprHandle((uint8_t)3), body);
|
||||
SimpleIRExprEval eval(result);
|
||||
EXPECT_EQ(eval.value<uint8_t>(), 2 + (3 * 3 + 4));
|
||||
ASSERT_EQ(eval.value<uint8_t>(), 2 + (3 * 3 + 4));
|
||||
}
|
||||
|
||||
void testExprCharTest() {
|
||||
@ -130,7 +130,7 @@ void testExprCharTest() {
|
||||
(x * ExprHandle((int8_t)3) + ExprHandle((int8_t)4));
|
||||
ExprHandle result = Let::make(x, ExprHandle((int8_t)3), body);
|
||||
SimpleIRExprEval eval(result);
|
||||
EXPECT_EQ(eval.value<int8_t>(), 2 + (3 * 3 + 4));
|
||||
ASSERT_EQ(eval.value<int8_t>(), 2 + (3 * 3 + 4));
|
||||
}
|
||||
|
||||
void testExprShortTest() {
|
||||
@ -141,7 +141,7 @@ void testExprShortTest() {
|
||||
(x * ExprHandle((int16_t)3) + ExprHandle((int16_t)4));
|
||||
ExprHandle result = Let::make(x, ExprHandle((int16_t)3), body);
|
||||
SimpleIRExprEval eval(result);
|
||||
EXPECT_EQ(eval.value<int16_t>(), 2 + (3 * 3 + 4));
|
||||
ASSERT_EQ(eval.value<int16_t>(), 2 + (3 * 3 + 4));
|
||||
}
|
||||
|
||||
void testExprLongTest() {
|
||||
@ -152,7 +152,7 @@ void testExprLongTest() {
|
||||
(x * ExprHandle((int64_t)3) + ExprHandle((int64_t)4));
|
||||
ExprHandle result = Let::make(x, ExprHandle((int64_t)3), body);
|
||||
SimpleIRExprEval eval(result);
|
||||
EXPECT_EQ(eval.value<int64_t>(), 2 + (3 * 3 + 4));
|
||||
ASSERT_EQ(eval.value<int64_t>(), 2 + (3 * 3 + 4));
|
||||
}
|
||||
|
||||
void testExprHalfTest() {
|
||||
@ -163,7 +163,7 @@ void testExprHalfTest() {
|
||||
(x * ExprHandle((at::Half)3) + ExprHandle((at::Half)4));
|
||||
ExprHandle result = Let::make(x, ExprHandle((at::Half)3), body);
|
||||
SimpleIRExprEval eval(result);
|
||||
EXPECT_EQ(eval.value<at::Half>(), 2 + (3 * 3 + 4));
|
||||
ASSERT_EQ(eval.value<at::Half>(), 2 + (3 * 3 + 4));
|
||||
}
|
||||
|
||||
void testExprDoubleTest() {
|
||||
@ -174,7 +174,7 @@ void testExprDoubleTest() {
|
||||
(x * ExprHandle((double)3) + ExprHandle((double)4));
|
||||
ExprHandle result = Let::make(x, ExprHandle((double)3), body);
|
||||
SimpleIRExprEval eval(result);
|
||||
EXPECT_EQ(eval.value<double>(), 2 + (3 * 3 + 4));
|
||||
ASSERT_EQ(eval.value<double>(), 2 + (3 * 3 + 4));
|
||||
}
|
||||
void testExprVectorAdd01() {
|
||||
KernelScope kernel_scope;
|
||||
@ -211,9 +211,9 @@ void testExprVectorAdd01() {
|
||||
Broadcast::make(1, kVectorSize));
|
||||
Stmt* stmt = For::make(index, 0, kVectorCount, store_c);
|
||||
|
||||
EXPECT_EQ(load_a.dtype(), Dtype(kFloat, kVectorSize));
|
||||
EXPECT_EQ(load_b.dtype(), Dtype(kFloat, kVectorSize));
|
||||
EXPECT_EQ(value.dtype(), Dtype(kFloat, kVectorSize));
|
||||
ASSERT_EQ(load_a.dtype(), Dtype(kFloat, kVectorSize));
|
||||
ASSERT_EQ(load_b.dtype(), Dtype(kFloat, kVectorSize));
|
||||
ASSERT_EQ(value.dtype(), Dtype(kFloat, kVectorSize));
|
||||
|
||||
PaddedBuffer<float> a_v(kTotalSize);
|
||||
PaddedBuffer<float> b_v(kTotalSize);
|
||||
@ -359,7 +359,7 @@ void testExprUnaryMath01() {
|
||||
ExprHandle v = test_config.func(ExprHandle(input_v));
|
||||
float v_ref = test_config.ref_func(input_v);
|
||||
SimpleIRExprEval eval(v);
|
||||
EXPECT_NEAR(eval.value<float>(), v_ref, 1e-6) << "fail: " << v;
|
||||
ASSERT_NEAR(eval.value<float>(), v_ref, 1e-6, "fail: ", v);
|
||||
}
|
||||
}
|
||||
|
||||
@ -383,7 +383,7 @@ void testExprBinaryMath01() {
|
||||
ExprHandle v_expr = test_config.func(ExprHandle(v1), ExprHandle(v2));
|
||||
float v_ref = test_config.ref_func(v1, v2);
|
||||
SimpleIRExprEval eval(v_expr);
|
||||
EXPECT_NEAR(eval.value<float>(), v_ref, 1e-6) << "fail: " << v_expr;
|
||||
ASSERT_NEAR(eval.value<float>(), v_ref, 1e-6, "fail: ", v_expr);
|
||||
}
|
||||
}
|
||||
|
||||
@ -396,7 +396,7 @@ void testExprBitwiseOps() {
|
||||
ExprHandle f = (((a ^ (b << 1)) & c) >> 2) | d;
|
||||
|
||||
SimpleIRExprEval eval(f);
|
||||
EXPECT_EQ(eval.value<int>(), 11);
|
||||
ASSERT_EQ(eval.value<int>(), 11);
|
||||
}
|
||||
|
||||
void testExprDynamicShapeAdd() {
|
||||
|
@ -18,7 +18,7 @@ void testIRPrinterBasicValueTest() {
|
||||
|
||||
std::stringstream ss;
|
||||
ss << c;
|
||||
EXPECT_EQ(ss.str(), "2 + 3");
|
||||
ASSERT_EQ(ss.str(), "2 + 3");
|
||||
}
|
||||
|
||||
void testIRPrinterBasicValueTest02() {
|
||||
@ -31,7 +31,7 @@ void testIRPrinterBasicValueTest02() {
|
||||
|
||||
std::stringstream ss;
|
||||
ss << f;
|
||||
EXPECT_EQ(ss.str(), "(2.f + 3.f) - (4.f + 5.f)");
|
||||
ASSERT_EQ(ss.str(), "(2.f + 3.f) - (4.f + 5.f)");
|
||||
}
|
||||
|
||||
void testIRPrinterLetTest01() {
|
||||
@ -43,7 +43,7 @@ void testIRPrinterLetTest01() {
|
||||
|
||||
std::stringstream ss;
|
||||
ss << result;
|
||||
EXPECT_EQ(ss.str(), "let x = 3.f in 2.f + (x * 3.f + 4.f)");
|
||||
ASSERT_EQ(ss.str(), "let x = 3.f in 2.f + (x * 3.f + 4.f)");
|
||||
}
|
||||
|
||||
void testIRPrinterLetTest02() {
|
||||
@ -58,7 +58,7 @@ void testIRPrinterLetTest02() {
|
||||
|
||||
std::stringstream ss;
|
||||
ss << e2;
|
||||
EXPECT_EQ(
|
||||
ASSERT_EQ(
|
||||
ss.str(), "let y = 6.f in (let x = 3.f in 2.f + (x * 3.f + 4.f * y))");
|
||||
}
|
||||
|
||||
@ -74,7 +74,7 @@ void testIRPrinterCastTest() {
|
||||
|
||||
std::stringstream ss;
|
||||
ss << e2;
|
||||
EXPECT_EQ(
|
||||
ASSERT_EQ(
|
||||
ss.str(),
|
||||
"let y = 6.f in (let x = int(3.f) in 2.f + (x * 3.f + 4.f * y))");
|
||||
}
|
||||
|
@ -38,9 +38,9 @@ using LLVMExprEval = ExprEval<LLVMCodeGen>;
|
||||
auto a = Name##Imm::make(Val); \
|
||||
LLVMExprEval cg(a); \
|
||||
if (std::is_floating_point<decltype(Val)>()) { \
|
||||
EXPECT_NEAR(cg.value<Type>(), Val, 0.1); \
|
||||
ASSERT_NEAR(cg.value<Type>(), Val, 0.1); \
|
||||
} else { \
|
||||
EXPECT_EQ(cg.value<Type>(), Val); \
|
||||
ASSERT_EQ(cg.value<Type>(), Val); \
|
||||
} \
|
||||
}
|
||||
TEST_LLVM_SCALAR_TYPES(IMM_TEST)
|
||||
@ -54,9 +54,9 @@ TEST_LLVM_SCALAR_TYPES(IMM_TEST)
|
||||
auto c = Add::make(a, b); \
|
||||
LLVMExprEval cg(c); \
|
||||
if (std::is_floating_point<decltype(Val)>()) { \
|
||||
EXPECT_NEAR(cg.value<Type>(), Val * 3, 0.1); \
|
||||
ASSERT_NEAR(cg.value<Type>(), Val * 3, 0.1); \
|
||||
} else { \
|
||||
EXPECT_EQ(cg.value<Type>(), Val * 3); \
|
||||
ASSERT_EQ(cg.value<Type>(), Val * 3); \
|
||||
} \
|
||||
}
|
||||
TEST_LLVM_SCALAR_TYPES(ADD_TEST)
|
||||
@ -70,9 +70,9 @@ TEST_LLVM_SCALAR_TYPES(ADD_TEST)
|
||||
auto c = Sub::make(a, b); \
|
||||
LLVMExprEval cg(c); \
|
||||
if (std::is_floating_point<decltype(Val)>()) { \
|
||||
EXPECT_NEAR(cg.value<Type>(), Val, 0.1); \
|
||||
ASSERT_NEAR(cg.value<Type>(), Val, 0.1); \
|
||||
} else { \
|
||||
EXPECT_EQ(cg.value<Type>(), Val); \
|
||||
ASSERT_EQ(cg.value<Type>(), Val); \
|
||||
} \
|
||||
}
|
||||
TEST_LLVM_SCALAR_TYPES(SUB_TEST)
|
||||
@ -86,9 +86,9 @@ TEST_LLVM_SCALAR_TYPES(SUB_TEST)
|
||||
auto c = Mul::make(a, b); \
|
||||
LLVMExprEval cg(c); \
|
||||
if (std::is_floating_point<decltype(Val)>()) { \
|
||||
EXPECT_NEAR(cg.value<Type>(), Val * 4, 0.1); \
|
||||
ASSERT_NEAR(cg.value<Type>(), Val * 4, 0.1); \
|
||||
} else { \
|
||||
EXPECT_EQ(cg.value<Type>(), Val * 4); \
|
||||
ASSERT_EQ(cg.value<Type>(), Val * 4); \
|
||||
} \
|
||||
}
|
||||
TEST_LLVM_SCALAR_TYPES(MUL_TEST)
|
||||
@ -102,9 +102,9 @@ TEST_LLVM_SCALAR_TYPES(MUL_TEST)
|
||||
auto c = Div::make(a, b); \
|
||||
LLVMExprEval cg(c); \
|
||||
if (std::is_floating_point<decltype(Val)>()) { \
|
||||
EXPECT_NEAR(cg.value<Type>(), 2, 0.1); \
|
||||
ASSERT_NEAR(cg.value<Type>(), 2, 0.1); \
|
||||
} else { \
|
||||
EXPECT_EQ(cg.value<Type>(), 2); \
|
||||
ASSERT_EQ(cg.value<Type>(), 2); \
|
||||
} \
|
||||
}
|
||||
TEST_LLVM_SCALAR_TYPES(DIV_TEST)
|
||||
@ -115,7 +115,7 @@ void testLLVMIntToFloatCastTest() {
|
||||
auto a = IntImm::make(2);
|
||||
auto b = Cast::make(kFloat, a);
|
||||
LLVMExprEval cg(b, {});
|
||||
EXPECT_EQ(cg.value<float>(), 2.0);
|
||||
ASSERT_EQ(cg.value<float>(), 2.0);
|
||||
}
|
||||
|
||||
void testLLVMFloatToIntCastTest() {
|
||||
@ -123,7 +123,7 @@ void testLLVMFloatToIntCastTest() {
|
||||
auto a = FloatImm::make(2.0);
|
||||
auto b = Cast::make(kInt, a);
|
||||
LLVMExprEval cg(b);
|
||||
EXPECT_EQ(cg.value<int>(), 2);
|
||||
ASSERT_EQ(cg.value<int>(), 2);
|
||||
}
|
||||
|
||||
void testLLVMIntToLongCastTest() {
|
||||
@ -131,7 +131,7 @@ void testLLVMIntToLongCastTest() {
|
||||
auto a = IntImm::make(12345);
|
||||
auto b = Cast::make(kLong, a);
|
||||
LLVMExprEval cg(b);
|
||||
EXPECT_EQ(cg.value<int64_t>(), 12345);
|
||||
ASSERT_EQ(cg.value<int64_t>(), 12345);
|
||||
}
|
||||
|
||||
void testLLVMByteToCharCastTest() {
|
||||
@ -139,7 +139,7 @@ void testLLVMByteToCharCastTest() {
|
||||
auto a = ByteImm::make(250);
|
||||
auto b = Cast::make(kChar, a);
|
||||
LLVMExprEval cg(b);
|
||||
EXPECT_EQ(cg.value<int8_t>(), (int8_t)250);
|
||||
ASSERT_EQ(cg.value<int8_t>(), (int8_t)250);
|
||||
}
|
||||
|
||||
void testLLVMHalfToLongCastTest() {
|
||||
@ -147,7 +147,7 @@ void testLLVMHalfToLongCastTest() {
|
||||
auto a = HalfImm::make(2.0);
|
||||
auto b = Cast::make(kLong, a);
|
||||
LLVMExprEval cg(b);
|
||||
EXPECT_EQ(cg.value<int64_t>(), 2);
|
||||
ASSERT_EQ(cg.value<int64_t>(), 2);
|
||||
}
|
||||
|
||||
void testLLVMByteToDoubleCastTest() {
|
||||
@ -155,7 +155,7 @@ void testLLVMByteToDoubleCastTest() {
|
||||
auto a = ByteImm::make(2);
|
||||
auto b = Cast::make(kDouble, a);
|
||||
LLVMExprEval cg(b);
|
||||
EXPECT_EQ(cg.value<double>(), 2);
|
||||
ASSERT_EQ(cg.value<double>(), 2);
|
||||
}
|
||||
|
||||
void testLLVMLetTest01() {
|
||||
@ -165,7 +165,7 @@ void testLLVMLetTest01() {
|
||||
ExprHandle body = ExprHandle(2.f) + (x * ExprHandle(3.f) + ExprHandle(4.f));
|
||||
ExprHandle result = Let::make(x, ExprHandle(3.f), body);
|
||||
LLVMExprEval cg(result, {});
|
||||
EXPECT_EQ(cg.value<float>(), 2.f + (3.f * 3.f + 4.f));
|
||||
ASSERT_EQ(cg.value<float>(), 2.f + (3.f * 3.f + 4.f));
|
||||
}
|
||||
|
||||
void testLLVMLetTest02() {
|
||||
@ -178,7 +178,7 @@ void testLLVMLetTest02() {
|
||||
ExprHandle e1 = Let::make(x, ExprHandle(3.f), body);
|
||||
ExprHandle e2 = Let::make(y, ExprHandle(6.f), e1);
|
||||
LLVMExprEval cg(e2, {});
|
||||
EXPECT_EQ(cg.value<float>(), 2.f + (3.f * 3.f + 4.f * 6.f));
|
||||
ASSERT_EQ(cg.value<float>(), 2.f + (3.f * 3.f + 4.f * 6.f));
|
||||
}
|
||||
|
||||
void testLLVMLetTestMultitype() {
|
||||
@ -191,7 +191,7 @@ void testLLVMLetTestMultitype() {
|
||||
ExprHandle e1 = Let::make(x, ExprHandle((uint8_t)3), body);
|
||||
ExprHandle e2 = Let::make(y, ExprHandle((at::Half)6.f), e1);
|
||||
LLVMExprEval cg(e2, {});
|
||||
EXPECT_EQ(cg.value<double>(), 2.f + (3 * 3 + 4 * 6.f));
|
||||
ASSERT_EQ(cg.value<double>(), 2.f + (3 * 3 + 4 * 6.f));
|
||||
}
|
||||
|
||||
void testLLVMBufferTest() {
|
||||
@ -201,7 +201,7 @@ void testLLVMBufferTest() {
|
||||
std::vector<void*> args({v.data()});
|
||||
auto rv = IntImm::make(0);
|
||||
LLVMExprEval cg(rv, {a});
|
||||
EXPECT_EQ(cg.value<int>(args), 0);
|
||||
ASSERT_EQ(cg.value<int>(args), 0);
|
||||
}
|
||||
|
||||
void testLLVMBlockTest() {
|
||||
@ -217,9 +217,9 @@ void testLLVMBlockTest() {
|
||||
});
|
||||
|
||||
LLVMCodeGen cg(block, {a});
|
||||
EXPECT_EQ(cg.value<int>(args), 0);
|
||||
EXPECT_EQ(v[0], 4);
|
||||
EXPECT_EQ(v[1], 4);
|
||||
ASSERT_EQ(cg.value<int>(args), 0);
|
||||
ASSERT_EQ(v[0], 4);
|
||||
ASSERT_EQ(v[1], 4);
|
||||
}
|
||||
|
||||
void testLLVMLoadStoreTest() {
|
||||
@ -236,9 +236,9 @@ void testLLVMLoadStoreTest() {
|
||||
IntImm::make(1));
|
||||
LLVMCodeGen cg(store, {a, b});
|
||||
std::vector<void*> args({a_buffer.data(), b_buffer.data()});
|
||||
EXPECT_EQ(cg.value<int>(args), 0);
|
||||
EXPECT_EQ(a_buffer[0], 42);
|
||||
EXPECT_EQ(b_buffer[0], 42);
|
||||
ASSERT_EQ(cg.value<int>(args), 0);
|
||||
ASSERT_EQ(a_buffer[0], 42);
|
||||
ASSERT_EQ(b_buffer[0], 42);
|
||||
}
|
||||
|
||||
void testLLVMIfThenElseTest() {
|
||||
@ -260,9 +260,9 @@ void testLLVMIfThenElseTest() {
|
||||
IntImm::make(1));
|
||||
LLVMCodeGen cg(store, {a, b, c});
|
||||
std::vector<void*> args({a_buffer.data(), b_buffer.data(), c_buffer.data()});
|
||||
EXPECT_EQ(cg.value<int>(args), 0);
|
||||
EXPECT_EQ(a_buffer[0], 42);
|
||||
EXPECT_EQ(b_buffer[0], 42);
|
||||
ASSERT_EQ(cg.value<int>(args), 0);
|
||||
ASSERT_EQ(a_buffer[0], 42);
|
||||
ASSERT_EQ(b_buffer[0], 42);
|
||||
}
|
||||
|
||||
void testLLVMVecLoadStoreTest() {
|
||||
@ -279,22 +279,22 @@ void testLLVMVecLoadStoreTest() {
|
||||
Broadcast::make(IntImm::make(1), 4));
|
||||
LLVMCodeGen cg(store, {a, b});
|
||||
std::vector<void*> args({a_buffer.data(), b_buffer.data()});
|
||||
EXPECT_EQ(cg.value<int>(args), 0);
|
||||
EXPECT_EQ(a_buffer[0], 1);
|
||||
EXPECT_EQ(a_buffer[1], 1);
|
||||
EXPECT_EQ(a_buffer[2], 1);
|
||||
EXPECT_EQ(a_buffer[3], 1);
|
||||
EXPECT_EQ(b_buffer[0], 1);
|
||||
EXPECT_EQ(b_buffer[1], 1);
|
||||
EXPECT_EQ(b_buffer[2], 1);
|
||||
EXPECT_EQ(b_buffer[3], 1);
|
||||
ASSERT_EQ(cg.value<int>(args), 0);
|
||||
ASSERT_EQ(a_buffer[0], 1);
|
||||
ASSERT_EQ(a_buffer[1], 1);
|
||||
ASSERT_EQ(a_buffer[2], 1);
|
||||
ASSERT_EQ(a_buffer[3], 1);
|
||||
ASSERT_EQ(b_buffer[0], 1);
|
||||
ASSERT_EQ(b_buffer[1], 1);
|
||||
ASSERT_EQ(b_buffer[2], 1);
|
||||
ASSERT_EQ(b_buffer[3], 1);
|
||||
}
|
||||
|
||||
#define FLOAT_INTRINSICS_TEST(Name, Lanes) \
|
||||
void testLLVMVecFloat_##Name##Lane##Lanes##Test() { \
|
||||
void testLLVMVecFloat_##Name##Lane##Lanes##Test() { \
|
||||
KernelScope kernel_scope; \
|
||||
Buffer a(VarHandle("A", kHandle), kFloat, {1}); \
|
||||
Buffer b(VarHandle("B", kHandle), kFloat, {1}); \
|
||||
Buffer a(VarHandle("A", kHandle), kFloat, {1}); \
|
||||
Buffer b(VarHandle("B", kHandle), kFloat, {1}); \
|
||||
float val = 0.5f; \
|
||||
std::vector<float> a_buffer(Lanes, val); \
|
||||
std::vector<float> b_buffer(Lanes, val); \
|
||||
@ -309,9 +309,9 @@ void testLLVMVecLoadStoreTest() {
|
||||
LLVMCodeGen cg(store, {a, b}); \
|
||||
std::vector<void*> args({a_buffer.data(), b_buffer.data()}); \
|
||||
float ref = std::Name(0.5f); \
|
||||
EXPECT_EQ(cg.value<int>(args), 0); \
|
||||
ASSERT_EQ(cg.value<int>(args), 0); \
|
||||
for (int i = 0; i < Lanes; i++) { \
|
||||
EXPECT_FLOAT_EQ(a_buffer[i], val); \
|
||||
ASSERT_FLOAT_EQ(a_buffer[i], val); \
|
||||
} \
|
||||
} // namespace jit
|
||||
FLOAT_INTRINSICS_TEST(erf, 4)
|
||||
@ -336,14 +336,14 @@ FLOAT_INTRINSICS_TEST(expm1, 8)
|
||||
FLOAT_INTRINSICS_TEST(lgamma, 8)
|
||||
#undef FLOAT_INTRINSICS_TEST
|
||||
|
||||
#define DOUBLE_INTRINSICS_TEST(Name, Lanes) \
|
||||
void testLLVMVecDouble_##Name##Lane##Lanes##Test() { \
|
||||
#define DOUBLE_INTRINSICS_TEST(Name, Lanes) \
|
||||
void testLLVMVecDouble_##Name##Lane##Lanes##Test() { \
|
||||
KernelScope kernel_scope; \
|
||||
Buffer a(VarHandle("A", kHandle), kDouble, {1}); \
|
||||
Buffer b(VarHandle("B", kHandle), kDouble, {1}); \
|
||||
Buffer a(VarHandle("A", kHandle), kDouble, {1}); \
|
||||
Buffer b(VarHandle("B", kHandle), kDouble, {1}); \
|
||||
float val = 0.5f; \
|
||||
std::vector<double> a_buffer(Lanes, val); \
|
||||
std::vector<double> b_buffer(Lanes, val); \
|
||||
std::vector<double> a_buffer(Lanes, val); \
|
||||
std::vector<double> b_buffer(Lanes, val); \
|
||||
auto store = Store::make( \
|
||||
b, \
|
||||
Ramp::make(0, 1, Lanes), \
|
||||
@ -355,9 +355,9 @@ FLOAT_INTRINSICS_TEST(lgamma, 8)
|
||||
LLVMCodeGen cg(store, {a, b}); \
|
||||
std::vector<void*> args({a_buffer.data(), b_buffer.data()}); \
|
||||
float ref = std::Name(0.5f); \
|
||||
EXPECT_EQ(cg.value<int>(args), 0); \
|
||||
ASSERT_EQ(cg.value<int>(args), 0); \
|
||||
for (int i = 0; i < Lanes; i++) { \
|
||||
EXPECT_FLOAT_EQ(a_buffer[i], val); \
|
||||
ASSERT_FLOAT_EQ(a_buffer[i], val); \
|
||||
} \
|
||||
} // namespace jit
|
||||
DOUBLE_INTRINSICS_TEST(erf, 2)
|
||||
@ -395,7 +395,8 @@ void testLLVMVectorizerLoadStoreTest() {
|
||||
Stmt* s = l.root_stmt();
|
||||
l.vectorize(*dynamic_cast<Block*>(s)->stmts().begin());
|
||||
|
||||
EXPECT_TRUE(dynamic_cast<For*>(*dynamic_cast<Block*>(s)->stmts().begin()) == nullptr);
|
||||
ASSERT_TRUE(
|
||||
dynamic_cast<For*>(*dynamic_cast<Block*>(s)->stmts().begin()) == nullptr);
|
||||
|
||||
LLVMCodeGen cg(s, {a, c_buf});
|
||||
|
||||
@ -992,7 +993,7 @@ void testLLVMStoreFloat() {
|
||||
LLVMCodeGen cg(expr, {result});
|
||||
std::vector<void*> args({result_buffer.data()});
|
||||
ASSERT_EQ(cg.value<int>(args), 0);
|
||||
EXPECT_EQ(result_buffer[0], 3.14f);
|
||||
ASSERT_EQ(result_buffer[0], 3.14f);
|
||||
}
|
||||
|
||||
void testLLVMSimpleMath01() {
|
||||
@ -1083,7 +1084,7 @@ void testLLVMBitwiseOps() {
|
||||
ExprHandle f = (((a ^ (b << 1)) & c) >> 2) | d;
|
||||
LLVMExprEval cg(f);
|
||||
|
||||
EXPECT_EQ(cg.value<int>(), 11);
|
||||
ASSERT_EQ(cg.value<int>(), 11);
|
||||
}
|
||||
|
||||
void testLLVMDynamicShapeAdd() {
|
||||
|
@ -18,11 +18,11 @@ void testConstantFoldSimple() {
|
||||
ExprHandle f = (a + b);
|
||||
|
||||
ExprHandle newF = IRSimplifier::simplify(f);
|
||||
EXPECT_NE(newF.AsNode<FloatImm>(), nullptr);
|
||||
EXPECT_EQ(newF.AsNode<FloatImm>()->value(), 5);
|
||||
ASSERT_NE(newF.AsNode<FloatImm>(), nullptr);
|
||||
ASSERT_EQ(newF.AsNode<FloatImm>()->value(), 5);
|
||||
|
||||
SimpleIRExprEval eval(newF);
|
||||
EXPECT_EQ(eval.value<float>(), 5.f);
|
||||
ASSERT_EQ(eval.value<float>(), 5.f);
|
||||
}
|
||||
|
||||
void testConstantFoldTwoLayer() {
|
||||
@ -34,11 +34,11 @@ void testConstantFoldTwoLayer() {
|
||||
ExprHandle f = (a + b) - (c + d);
|
||||
|
||||
ExprHandle newF = IRSimplifier::simplify(f);
|
||||
EXPECT_NE(newF.AsNode<FloatImm>(), nullptr);
|
||||
EXPECT_EQ(newF.AsNode<FloatImm>()->value(), -4);
|
||||
ASSERT_NE(newF.AsNode<FloatImm>(), nullptr);
|
||||
ASSERT_EQ(newF.AsNode<FloatImm>()->value(), -4);
|
||||
|
||||
SimpleIRExprEval eval(newF);
|
||||
EXPECT_EQ(eval.value<float>(), -4.f);
|
||||
ASSERT_EQ(eval.value<float>(), -4.f);
|
||||
}
|
||||
|
||||
void testConstantFoldShifts() {
|
||||
@ -49,11 +49,11 @@ void testConstantFoldShifts() {
|
||||
ExprHandle f = ((a << b) << b) >> c;
|
||||
|
||||
ExprHandle newF = IRSimplifier::simplify(f);
|
||||
EXPECT_NE(newF.AsNode<IntImm>(), nullptr);
|
||||
EXPECT_EQ(newF.AsNode<IntImm>()->value(), 14);
|
||||
ASSERT_NE(newF.AsNode<IntImm>(), nullptr);
|
||||
ASSERT_EQ(newF.AsNode<IntImm>()->value(), 14);
|
||||
|
||||
SimpleIRExprEval eval(newF);
|
||||
EXPECT_EQ(eval.value<int>(), 7 << (4 - 3));
|
||||
ASSERT_EQ(eval.value<int>(), 7 << (4 - 3));
|
||||
}
|
||||
|
||||
void testConstantFoldBitwise() {
|
||||
@ -64,11 +64,11 @@ void testConstantFoldBitwise() {
|
||||
ExprHandle f = (a ^ b) & c;
|
||||
|
||||
ExprHandle newF = IRSimplifier::simplify(f);
|
||||
EXPECT_NE(newF.AsNode<IntImm>(), nullptr);
|
||||
EXPECT_EQ(newF.AsNode<IntImm>()->value(), 37);
|
||||
ASSERT_NE(newF.AsNode<IntImm>(), nullptr);
|
||||
ASSERT_EQ(newF.AsNode<IntImm>()->value(), 37);
|
||||
|
||||
SimpleIRExprEval eval(newF);
|
||||
EXPECT_EQ(eval.value<int>(), (59 ^ 22) & 101);
|
||||
ASSERT_EQ(eval.value<int>(), (59 ^ 22) & 101);
|
||||
}
|
||||
|
||||
void testConstantFoldMultiOp() {
|
||||
@ -82,12 +82,12 @@ void testConstantFoldMultiOp() {
|
||||
ExprHandle fn = ((a / e) - (c + d)) * (f / b);
|
||||
|
||||
ExprHandle newF = IRSimplifier::simplify(fn);
|
||||
EXPECT_NE(newF.AsNode<FloatImm>(), nullptr);
|
||||
ASSERT_NE(newF.AsNode<FloatImm>(), nullptr);
|
||||
|
||||
SimpleIRExprEval eval(newF);
|
||||
SimpleIRExprEval ref(fn);
|
||||
|
||||
EXPECT_EQ(eval.value<float>(), ref.value<float>());
|
||||
ASSERT_EQ(eval.value<float>(), ref.value<float>());
|
||||
}
|
||||
|
||||
void testConstantFoldMinMax() {
|
||||
@ -100,13 +100,13 @@ void testConstantFoldMinMax() {
|
||||
ExprHandle minHandle = Min::make(b, c, true);
|
||||
ExprHandle fn = Max::make(a, minHandle, false);
|
||||
|
||||
EXPECT_EQ(fn.dtype().scalar_type(), ScalarType::Float);
|
||||
ASSERT_EQ(fn.dtype().scalar_type(), ScalarType::Float);
|
||||
|
||||
ExprHandle newF = IRSimplifier::simplify(fn);
|
||||
EXPECT_NE(newF.AsNode<FloatImm>(), nullptr);
|
||||
ASSERT_NE(newF.AsNode<FloatImm>(), nullptr);
|
||||
|
||||
SimpleIRExprEval eval(newF);
|
||||
EXPECT_EQ(eval.value<float>(), 15.f);
|
||||
ASSERT_EQ(eval.value<float>(), 15.f);
|
||||
}
|
||||
|
||||
void testConstantFoldIntrinsics() {
|
||||
@ -122,13 +122,13 @@ void testConstantFoldIntrinsics() {
|
||||
ExprHandle fn = Intrinsics::make(kFabs, rndHandle);
|
||||
|
||||
ExprHandle newF = IRSimplifier::simplify(fn);
|
||||
EXPECT_NE(newF.AsNode<FloatImm>(), nullptr);
|
||||
EXPECT_EQ(newF.AsNode<FloatImm>()->value(), 1);
|
||||
ASSERT_NE(newF.AsNode<FloatImm>(), nullptr);
|
||||
ASSERT_EQ(newF.AsNode<FloatImm>()->value(), 1);
|
||||
|
||||
SimpleIRExprEval eval(newF);
|
||||
SimpleIRExprEval ref(fn);
|
||||
|
||||
EXPECT_EQ(eval.value<float>(), ref.value<float>());
|
||||
ASSERT_EQ(eval.value<float>(), ref.value<float>());
|
||||
}
|
||||
|
||||
void testConstantFoldWithVar() {
|
||||
@ -138,12 +138,12 @@ void testConstantFoldWithVar() {
|
||||
|
||||
ExprHandle newF = IRSimplifier::simplify(body);
|
||||
const Mul* root = newF.AsNode<Mul>();
|
||||
EXPECT_NE(root, nullptr);
|
||||
EXPECT_NE(dynamic_cast<const FloatImm*>(root->rhs()), nullptr);
|
||||
ASSERT_NE(root, nullptr);
|
||||
ASSERT_NE(dynamic_cast<const FloatImm*>(root->rhs()), nullptr);
|
||||
|
||||
ExprHandle result = Let::make(x, ExprHandle(3.f), newF);
|
||||
SimpleIRExprEval eval(result);
|
||||
EXPECT_EQ(eval.value<float>(), 3 * (2 + 4));
|
||||
ASSERT_EQ(eval.value<float>(), 3 * (2 + 4));
|
||||
}
|
||||
|
||||
void testUnFoldableExpr() {
|
||||
@ -154,14 +154,14 @@ void testUnFoldableExpr() {
|
||||
|
||||
ExprHandle newF = IRSimplifier::simplify(body);
|
||||
const Add* root = newF.AsNode<Add>();
|
||||
EXPECT_NE(root, nullptr);
|
||||
EXPECT_EQ(dynamic_cast<const FloatImm*>(root->lhs()), nullptr);
|
||||
EXPECT_EQ(dynamic_cast<const FloatImm*>(root->rhs()), nullptr);
|
||||
ASSERT_NE(root, nullptr);
|
||||
ASSERT_EQ(dynamic_cast<const FloatImm*>(root->lhs()), nullptr);
|
||||
ASSERT_EQ(dynamic_cast<const FloatImm*>(root->rhs()), nullptr);
|
||||
|
||||
ExprHandle result = Let::make(x, ExprHandle(3.f), newF);
|
||||
result = Let::make(y, ExprHandle(2.f), result);
|
||||
SimpleIRExprEval eval(result);
|
||||
EXPECT_EQ(eval.value<float>(), 9 + 10);
|
||||
ASSERT_EQ(eval.value<float>(), 9 + 10);
|
||||
}
|
||||
|
||||
void testHashSimple() {
|
||||
@ -177,12 +177,12 @@ void testHashSimple() {
|
||||
auto hash_a = hasher.hash(a.node());
|
||||
auto hash_f = hasher.hash(f.node());
|
||||
|
||||
EXPECT_NE(hash_x, 0);
|
||||
EXPECT_NE(hash_a, 0);
|
||||
EXPECT_NE(hash_f, 0);
|
||||
EXPECT_NE(hash_x, hash_a);
|
||||
EXPECT_NE(hash_x, hash_f);
|
||||
EXPECT_NE(hash_a, hash_f);
|
||||
ASSERT_NE(hash_x, 0);
|
||||
ASSERT_NE(hash_a, 0);
|
||||
ASSERT_NE(hash_f, 0);
|
||||
ASSERT_NE(hash_x, hash_a);
ASSERT_NE(hash_x, hash_f);
ASSERT_NE(hash_a, hash_f);
}

void testHashEquivalence() {
@ -192,7 +192,7 @@ void testHashEquivalence() {
ExprHandle f = (x * y) + (x * y);

const Add* root = f.AsNode<Add>();
EXPECT_NE(root, nullptr);
ASSERT_NE(root, nullptr);

HashProvider hasher;
auto hash_f = hasher.hash(f.node());
@ -200,26 +200,26 @@ void testHashEquivalence() {
auto hash_r = hasher.hash(root->rhs());

// Root not equal to either branch.
EXPECT_NE(hash_f, hash_l);
EXPECT_NE(hash_f, hash_r);
ASSERT_NE(hash_f, hash_l);
ASSERT_NE(hash_f, hash_r);
// but branches are equal.
EXPECT_EQ(hash_l, hash_r);
ASSERT_EQ(hash_l, hash_r);

// Still equivalent if separate.
ExprHandle a(2);
ExprHandle f2 = x + a / y;
ExprHandle b(2);
ExprHandle f3 = x + b / y;
EXPECT_EQ(hasher.hash(f2.node()), hasher.hash(f3.node()));
ASSERT_EQ(hasher.hash(f2.node()), hasher.hash(f3.node()));

// Not equivalent if different vars (even with same name).
VarHandle z("x", kFloat);
ExprHandle f4 = z + b / y;
EXPECT_NE(hasher.hash(f2.node()), hasher.hash(f4.node()));
ASSERT_NE(hasher.hash(f2.node()), hasher.hash(f4.node()));

// Intrinsics sanity check.
ExprHandle f5 = Intrinsics::make(kSin, x) * Intrinsics::make(kCos, x);
EXPECT_NE(hasher.hash(f5.node()), 0);
ASSERT_NE(hasher.hash(f5.node()), 0);
}

void testHashEquivalenceAfterFolding() {
@ -231,7 +231,7 @@ void testHashEquivalenceAfterFolding() {
ExprHandle f = ((a + b) * x) * (c * x);

const Mul* root = f.AsNode<Mul>();
EXPECT_NE(root, nullptr);
ASSERT_NE(root, nullptr);

HashProvider hasher;
auto hash_f = hasher.hash(f.node());
@ -239,24 +239,24 @@ void testHashEquivalenceAfterFolding() {
auto hash_r = hasher.hash(root->rhs());

// Root not equal to either branch, and branches not equal.
EXPECT_NE(hash_f, hash_l);
EXPECT_NE(hash_f, hash_r);
EXPECT_NE(hash_l, hash_r);
ASSERT_NE(hash_f, hash_l);
ASSERT_NE(hash_f, hash_r);
ASSERT_NE(hash_l, hash_r);

ExprHandle newF = IRSimplifier::simplify(f);

const Mul* newRoot = newF.AsNode<Mul>();
EXPECT_NE(newRoot, nullptr);
ASSERT_NE(newRoot, nullptr);

auto hash_f_n = hasher.hash(newF.node());
auto hash_l_n = hasher.hash(newRoot->lhs());
auto hash_r_n = hasher.hash(newRoot->rhs());

// Root not equal to either branch.
EXPECT_NE(hash_f_n, hash_l_n);
EXPECT_NE(hash_f_n, hash_r_n);
ASSERT_NE(hash_f_n, hash_l_n);
ASSERT_NE(hash_f_n, hash_r_n);
// but branches are now equal.
EXPECT_EQ(hash_l_n, hash_r_n);
ASSERT_EQ(hash_l_n, hash_r_n);
}

void testHashDifferenceTypes() {
@ -278,7 +278,7 @@ void testHashDifferenceTypes() {
// Immediates of different types are not equal.
for (unsigned int i = 0; i < immediates.size(); ++i) {
for (unsigned int j = i + 1; j < immediates.size(); ++j) {
EXPECT_NE(hasher.hash(immediates[i]), hasher.hash(immediates[j]));
ASSERT_NE(hasher.hash(immediates[i]), hasher.hash(immediates[j]));
}
}

@ -289,7 +289,7 @@ void testHashDifferenceTypes() {
ExprHandle ff1 = IRSimplifier::simplify(f1);
ExprHandle ff2 = IRSimplifier::simplify(f2);

EXPECT_EQ(hasher.hash(ff1.node()), hasher.hash(ff2.node()));
ASSERT_EQ(hasher.hash(ff1.node()), hasher.hash(ff2.node()));
}

void testHashLargeExpression() {
@ -332,15 +332,15 @@ void testHashLargeExpression() {
HashProvider hasher;
auto hash_r = hasher.hash(if_stmt);
// We should not have to do any more work.
EXPECT_TRUE(hasher.cachedHash(memcpy_stmt));
ASSERT_TRUE(hasher.cachedHash(memcpy_stmt));
auto hash_t = hasher.hash(memcpy_stmt);
EXPECT_TRUE(hasher.cachedHash(store_ramp_stmt));
ASSERT_TRUE(hasher.cachedHash(store_ramp_stmt));
auto hash_f = hasher.hash(store_ramp_stmt);

// Root not equal to either branch, and branches not equal.
EXPECT_NE(hash_r, hash_t);
EXPECT_NE(hash_r, hash_f);
EXPECT_NE(hash_t, hash_f);
ASSERT_NE(hash_r, hash_t);
ASSERT_NE(hash_r, hash_f);
ASSERT_NE(hash_t, hash_f);
}

/// (2.f + x) + 4.f => x + 6.f
@ -351,13 +351,13 @@ void testSimplifyAdd() {

ExprHandle simplified = IRSimplifier::simplify(body);
const Add* root = simplified.AsNode<Add>();
EXPECT_NE(root, nullptr);
ASSERT_NE(root, nullptr);
const Var* lhs = dynamic_cast<const Var*>(root->lhs());
EXPECT_NE(lhs, nullptr);
EXPECT_EQ(lhs->name_hint(), "x");
ASSERT_NE(lhs, nullptr);
ASSERT_EQ(lhs->name_hint(), "x");
const FloatImm* rhs = dynamic_cast<const FloatImm*>(root->rhs());
EXPECT_NE(rhs, nullptr);
EXPECT_EQ(rhs->value(), 6.f);
ASSERT_NE(rhs, nullptr);
ASSERT_EQ(rhs->value(), 6.f);
}

/// (2.f - x) - 4.f => -2.f - x
@ -368,13 +368,13 @@ void testSimplifySub() {

ExprHandle simplified = IRSimplifier::simplify(body);
const Sub* root = simplified.AsNode<Sub>();
EXPECT_NE(root, nullptr);
ASSERT_NE(root, nullptr);
const FloatImm* lhs = dynamic_cast<const FloatImm*>(root->lhs());
EXPECT_NE(lhs, nullptr);
EXPECT_EQ(lhs->value(), -2.f);
ASSERT_NE(lhs, nullptr);
ASSERT_EQ(lhs->value(), -2.f);
const Var* rhs = dynamic_cast<const Var*>(root->rhs());
EXPECT_NE(rhs, nullptr);
EXPECT_EQ(rhs->name_hint(), "x");
ASSERT_NE(rhs, nullptr);
ASSERT_EQ(rhs->name_hint(), "x");
}

/// 2.f * (1.f - x) - 4.f => -6.f - (x * 2.f)
@ -385,18 +385,18 @@ void testSimplifyMultiLayer() {

ExprHandle simplified = IRSimplifier::simplify(body);
const Sub* root = simplified.AsNode<Sub>();
EXPECT_NE(root, nullptr);
ASSERT_NE(root, nullptr);
const FloatImm* lhs = dynamic_cast<const FloatImm*>(root->lhs());
EXPECT_NE(lhs, nullptr);
EXPECT_EQ(lhs->value(), -6.f);
ASSERT_NE(lhs, nullptr);
ASSERT_EQ(lhs->value(), -6.f);
const Mul* rhs = dynamic_cast<const Mul*>(root->rhs());
EXPECT_NE(rhs, nullptr);
ASSERT_NE(rhs, nullptr);
const Var* varX = dynamic_cast<const Var*>(rhs->lhs());
EXPECT_NE(varX, nullptr);
EXPECT_EQ(varX->name_hint(), "x");
ASSERT_NE(varX, nullptr);
ASSERT_EQ(varX->name_hint(), "x");
const FloatImm* mulRhs = dynamic_cast<const FloatImm*>(rhs->rhs());
EXPECT_NE(mulRhs, nullptr);
EXPECT_EQ(mulRhs->value(), 2.f);
ASSERT_NE(mulRhs, nullptr);
ASSERT_EQ(mulRhs->value(), 2.f);
}

/// 2 * (3 * x) - (x * 4) => x * 2
@ -408,13 +408,13 @@ void testSimplifyMultiTerm() {

ExprHandle simplified = IRSimplifier::simplify(body);
const Mul* root = simplified.AsNode<Mul>();
EXPECT_NE(root, nullptr);
ASSERT_NE(root, nullptr);
const Var* lhs = dynamic_cast<const Var*>(root->lhs());
EXPECT_NE(lhs, nullptr);
EXPECT_EQ(lhs->name_hint(), "x");
ASSERT_NE(lhs, nullptr);
ASSERT_EQ(lhs->name_hint(), "x");
const IntImm* rhs = dynamic_cast<const IntImm*>(root->rhs());
EXPECT_NE(rhs, nullptr);
EXPECT_EQ(rhs->value(), 2);
ASSERT_NE(rhs, nullptr);
ASSERT_EQ(rhs->value(), 2);
}

/// 2 * (3 * (f)x) - (x * 4) => x * 2.f
@ -426,13 +426,13 @@ void testSimplifyCasts() {

ExprHandle simplified = IRSimplifier::simplify(body);
const Mul* root = simplified.AsNode<Mul>();
EXPECT_NE(root, nullptr);
ASSERT_NE(root, nullptr);
const Var* lhs = dynamic_cast<const Var*>(root->lhs());
EXPECT_NE(lhs, nullptr);
EXPECT_EQ(lhs->name_hint(), "x");
ASSERT_NE(lhs, nullptr);
ASSERT_EQ(lhs->name_hint(), "x");
const FloatImm* rhs = dynamic_cast<const FloatImm*>(root->rhs());
EXPECT_NE(rhs, nullptr);
EXPECT_EQ(rhs->value(), 2);
ASSERT_NE(rhs, nullptr);
ASSERT_EQ(rhs->value(), 2);
}

/// (x + 0) * 1 => x
@ -443,8 +443,8 @@ void testSimplifyEliminatesNoOps() {

ExprHandle simplified = IRSimplifier::simplify(body);
const Var* root = simplified.AsNode<Var>();
EXPECT_NE(root, nullptr);
EXPECT_EQ(root->name_hint(), "x");
ASSERT_NE(root, nullptr);
ASSERT_EQ(root->name_hint(), "x");
}

/// Cannot simplify this.
@ -456,16 +456,16 @@ void testSimplifyMultiVar() {

ExprHandle simplified = IRSimplifier::simplify(body);
const Add* root = simplified.AsNode<Add>();
EXPECT_NE(root, nullptr);
ASSERT_NE(root, nullptr);
const Mul* lhs = dynamic_cast<const Mul*>(root->lhs());
EXPECT_NE(lhs, nullptr);
ASSERT_NE(lhs, nullptr);
const Var* varY = dynamic_cast<const Var*>(lhs->lhs());
EXPECT_EQ(varY->name_hint(), "y");
ASSERT_EQ(varY->name_hint(), "y");
const Mul* rhs = dynamic_cast<const Mul*>(root->rhs());
EXPECT_NE(rhs, nullptr);
ASSERT_NE(rhs, nullptr);
const Var* varX = dynamic_cast<const Var*>(rhs->lhs());
EXPECT_NE(varX, nullptr);
EXPECT_EQ(varX->name_hint(), "x");
ASSERT_NE(varX, nullptr);
ASSERT_EQ(varX->name_hint(), "x");
}

/// y + x * 0 => y
@ -477,8 +477,8 @@ void testSimplifyEliminatesVar() {

ExprHandle simplified = IRSimplifier::simplify(body);
const Var* root = simplified.AsNode<Var>();
EXPECT_NE(root, nullptr);
EXPECT_EQ(root->name_hint(), "y");
ASSERT_NE(root, nullptr);
ASSERT_EQ(root->name_hint(), "y");
}

} // namespace jit

@ -10,34 +10,34 @@ void testTypeTest01() {
KernelScope kernel_scope;
{
Dtype dt1 = kInt;
EXPECT_EQ(dt1, kInt);
ASSERT_EQ(dt1, kInt);
}
{
Dtype dt2_a(kInt, 8);
Dtype dt2_b(kInt, 4);
Dtype dt2_c(ScalarType::Int, 8);
EXPECT_EQ(dt2_a, dt2_c);
EXPECT_NE(dt2_a, dt2_b);
ASSERT_EQ(dt2_a, dt2_c);
ASSERT_NE(dt2_a, dt2_b);
}
{
EXPECT_EQ(kInt, ToDtype<int>());
EXPECT_EQ(kFloat, ToDtype<float>());
EXPECT_EQ(kByte, ToDtype<uint8_t>());
EXPECT_EQ(kChar, ToDtype<int8_t>());
EXPECT_EQ(kShort, ToDtype<int16_t>());
EXPECT_EQ(kLong, ToDtype<int64_t>());
EXPECT_EQ(kHalf, ToDtype<at::Half>());
EXPECT_EQ(kDouble, ToDtype<double>());
EXPECT_EQ(kBool, ToDtype<bool>());
ASSERT_EQ(kInt, ToDtype<int>());
ASSERT_EQ(kFloat, ToDtype<float>());
ASSERT_EQ(kByte, ToDtype<uint8_t>());
ASSERT_EQ(kChar, ToDtype<int8_t>());
ASSERT_EQ(kShort, ToDtype<int16_t>());
ASSERT_EQ(kLong, ToDtype<int64_t>());
ASSERT_EQ(kHalf, ToDtype<at::Half>());
ASSERT_EQ(kDouble, ToDtype<double>());
ASSERT_EQ(kBool, ToDtype<bool>());
}
{
Dtype int32x8(kInt, 8);
Dtype float32x8(kFloat, 8);
EXPECT_NE(int32x8, float32x8);
EXPECT_EQ(float32x8, BinaryOpDtype(int32x8, float32x8));
EXPECT_EQ(float32x8, BinaryOpDtype(float32x8, int32x8));
EXPECT_EQ(int32x8, BinaryOpDtype(int32x8, int32x8));
EXPECT_EQ(float32x8, BinaryOpDtype(float32x8, float32x8));
ASSERT_NE(int32x8, float32x8);
ASSERT_EQ(float32x8, BinaryOpDtype(int32x8, float32x8));
ASSERT_EQ(float32x8, BinaryOpDtype(float32x8, int32x8));
ASSERT_EQ(int32x8, BinaryOpDtype(int32x8, int32x8));
ASSERT_EQ(float32x8, BinaryOpDtype(float32x8, float32x8));
}
}

@ -51,7 +51,7 @@ void testTypePropagation() {
(x * FloatImm::make(3.f) + FloatImm::make(4.f) * y);
ExprHandle e1 = Let::make(x, FloatImm::make(3.f), body);
ExprHandle e2 = Let::make(y, FloatImm::make(6.f), e1);
EXPECT_EQ(e2.dtype(), kFloat);
ASSERT_EQ(e2.dtype(), kFloat);
}
// Int to bigger int:
{
@ -62,7 +62,7 @@ void testTypePropagation() {
ShortImm::make(2.f) + (x * ShortImm::make(3) + ShortImm::make(4) * y);
ExprHandle e1 = Let::make(x, ShortImm::make(3), body);
ExprHandle e2 = Let::make(y, LongImm::make(6), e1);
EXPECT_EQ(e2.dtype(), kLong);
ASSERT_EQ(e2.dtype(), kLong);
}
// Float to bigger float:
{
@ -73,7 +73,7 @@ void testTypePropagation() {
HalfImm::make(2.f) + (x * HalfImm::make(3) + HalfImm::make(4) * y);
ExprHandle e1 = Let::make(x, HalfImm::make(3), body);
ExprHandle e2 = Let::make(y, DoubleImm::make(6), e1);
EXPECT_EQ(e2.dtype(), kDouble);
ASSERT_EQ(e2.dtype(), kDouble);
}
// Int to Float:
{
@ -84,7 +84,7 @@ void testTypePropagation() {
IntImm::make(2) + (x * IntImm::make(3) + IntImm::make(4) * y);
ExprHandle e1 = Let::make(x, FloatImm::make(3.f), body);
ExprHandle e2 = Let::make(y, IntImm::make(6), e1);
EXPECT_EQ(e2.dtype(), kFloat);
ASSERT_EQ(e2.dtype(), kFloat);
}
// Smaller float, bigger Int:
{
@ -95,7 +95,7 @@ void testTypePropagation() {
HalfImm::make(2) + (x * HalfImm::make(3) + HalfImm::make(4) * y);
ExprHandle e1 = Let::make(x, HalfImm::make(3), body);
ExprHandle e2 = Let::make(y, LongImm::make(6), e1);
EXPECT_EQ(e2.dtype(), kHalf);
ASSERT_EQ(e2.dtype(), kHalf);
}
// Bigger float, smaller Int:
{
@ -106,7 +106,7 @@ void testTypePropagation() {
CharImm::make(2) + (x * CharImm::make(3) + CharImm::make(4) * y);
ExprHandle e1 = Let::make(x, CharImm::make(3), body);
ExprHandle e2 = Let::make(y, DoubleImm::make(6), e1);
EXPECT_EQ(e2.dtype(), kDouble);
ASSERT_EQ(e2.dtype(), kDouble);
}
// Sign change char/byte upgrades to short:
{
@ -117,7 +117,7 @@ void testTypePropagation() {
CharImm::make(2) + (x * CharImm::make(3) + CharImm::make(4) * y);
ExprHandle e1 = Let::make(x, CharImm::make(3), body);
ExprHandle e2 = Let::make(y, ByteImm::make(6), e1);
EXPECT_EQ(e2.dtype(), kShort);
ASSERT_EQ(e2.dtype(), kShort);
}
}
} // namespace jit

@ -9,240 +9,240 @@
namespace torch {
namespace jit {

#define TH_FORALL_TESTS(_) \
_(ExprBasicValueTest) \
_(ExprBasicValueTest02) \
_(ExprLetTest01) \
_(ExprLetStmtTest01) \
_(ExprLetTest02) \
_(ExprIntTest) \
_(ExprFloatTest) \
_(ExprByteTest) \
_(ExprCharTest) \
_(ExprShortTest) \
_(ExprLongTest) \
_(ExprHalfTest) \
_(ExprDoubleTest) \
_(ExprVectorAdd01) \
_(ExprCompareSelectEQ) \
_(ExprSubstitute01) \
_(ExprMath01) \
_(ExprUnaryMath01) \
_(ExprBinaryMath01) \
_(ExprDynamicShapeAdd) \
_(ExprBitwiseOps) \
_(IRPrinterBasicValueTest) \
_(IRPrinterBasicValueTest02) \
_(IRPrinterLetTest01) \
_(IRPrinterLetTest02) \
_(IRPrinterCastTest) \
_(ExprSimple01) \
_(ExprLower01) \
_(ExprSimple02) \
_(ExprSplitWithTailNone) \
_(ExprSplitWithMask01) \
_(ScheduleBroadcastAddBuffer) \
_(ScheduleFunctionCall01) \
_(ScheduleInlineFunc01) \
_(ScheduleFuserStyle) \
_(ScheduleFuserThreeArg) \
_(ScheduleDynamicShape2D) \
_(TypeTest01) \
_(TypePropagation) \
_(Cond01) \
_(IfThenElse01) \
_(IfThenElse02) \
_(ATen_cast_Float) \
_(ATennegInt) \
_(ATennegFloat) \
_(ATenaddInt) \
_(ATenaddFloat) \
_(ATensubInt) \
_(ATensubFloat) \
_(ATenlerp) \
_(ATenaddcmulInt) \
_(ATenaddcmulFloat) \
_(ATenmulInt) \
_(ATenmulFloat) \
_(ATendivInt) \
_(ATendivFloat) \
_(ATenmaxInt) \
_(ATenmaxFloat) \
_(ATenminInt) \
_(ATenminFloat) \
_(ATen_sigmoid_backward) \
_(ATen_tanh_backward) \
_(ATenreciprocal) \
_(ATenreluInt) \
_(ATenreluFloat) \
_(ATenlogFloat) \
_(ATenlog10Float) \
_(ATenlog2Float) \
_(ATenexpFloat) \
_(ATenerfFloat) \
_(ATencosFloat) \
_(ATeneqInt) \
_(ATengeInt) \
_(ATengtInt) \
_(ATenleInt) \
_(ATenltInt) \
_(ConstantFoldSimple) \
_(ConstantFoldTwoLayer) \
_(ConstantFoldShifts) \
_(ConstantFoldBitwise) \
_(ConstantFoldMultiOp) \
_(ConstantFoldMinMax) \
_(ConstantFoldIntrinsics) \
_(ConstantFoldWithVar) \
_(UnFoldableExpr) \
_(HashSimple) \
_(HashEquivalence) \
_(HashEquivalenceAfterFolding) \
_(HashDifferenceTypes) \
_(HashLargeExpression) \
_(SimplifyAdd) \
_(SimplifySub) \
_(SimplifyMultiLayer) \
_(SimplifyMultiTerm) \
_(SimplifyCasts) \
_(SimplifyEliminatesNoOps) \
_(SimplifyMultiVar) \
_(SimplifyEliminatesVar) \
#define TH_FORALL_TENSOREXPR_TESTS(_) \
_(ExprBasicValueTest) \
_(ExprBasicValueTest02) \
_(ExprLetTest01) \
_(ExprLetStmtTest01) \
_(ExprLetTest02) \
_(ExprIntTest) \
_(ExprFloatTest) \
_(ExprByteTest) \
_(ExprCharTest) \
_(ExprShortTest) \
_(ExprLongTest) \
_(ExprHalfTest) \
_(ExprDoubleTest) \
_(ExprVectorAdd01) \
_(ExprCompareSelectEQ) \
_(ExprSubstitute01) \
_(ExprMath01) \
_(ExprUnaryMath01) \
_(ExprBinaryMath01) \
_(ExprDynamicShapeAdd) \
_(ExprBitwiseOps) \
_(IRPrinterBasicValueTest) \
_(IRPrinterBasicValueTest02) \
_(IRPrinterLetTest01) \
_(IRPrinterLetTest02) \
_(IRPrinterCastTest) \
_(ExprSimple01) \
_(ExprLower01) \
_(ExprSimple02) \
_(ExprSplitWithTailNone) \
_(ExprSplitWithMask01) \
_(ScheduleBroadcastAddBuffer) \
_(ScheduleFunctionCall01) \
_(ScheduleInlineFunc01) \
_(ScheduleFuserStyle) \
_(ScheduleFuserThreeArg) \
_(ScheduleDynamicShape2D) \
_(TypeTest01) \
_(TypePropagation) \
_(Cond01) \
_(IfThenElse01) \
_(IfThenElse02) \
_(ATen_cast_Float) \
_(ATennegInt) \
_(ATennegFloat) \
_(ATenaddInt) \
_(ATenaddFloat) \
_(ATensubInt) \
_(ATensubFloat) \
_(ATenlerp) \
_(ATenaddcmulInt) \
_(ATenaddcmulFloat) \
_(ATenmulInt) \
_(ATenmulFloat) \
_(ATendivInt) \
_(ATendivFloat) \
_(ATenmaxInt) \
_(ATenmaxFloat) \
_(ATenminInt) \
_(ATenminFloat) \
_(ATen_sigmoid_backward) \
_(ATen_tanh_backward) \
_(ATenreciprocal) \
_(ATenreluInt) \
_(ATenreluFloat) \
_(ATenlogFloat) \
_(ATenlog10Float) \
_(ATenlog2Float) \
_(ATenexpFloat) \
_(ATenerfFloat) \
_(ATencosFloat) \
_(ATeneqInt) \
_(ATengeInt) \
_(ATengtInt) \
_(ATenleInt) \
_(ATenltInt) \
_(ConstantFoldSimple) \
_(ConstantFoldTwoLayer) \
_(ConstantFoldShifts) \
_(ConstantFoldBitwise) \
_(ConstantFoldMultiOp) \
_(ConstantFoldMinMax) \
_(ConstantFoldIntrinsics) \
_(ConstantFoldWithVar) \
_(UnFoldableExpr) \
_(HashSimple) \
_(HashEquivalence) \
_(HashEquivalenceAfterFolding) \
_(HashDifferenceTypes) \
_(HashLargeExpression) \
_(SimplifyAdd) \
_(SimplifySub) \
_(SimplifyMultiLayer) \
_(SimplifyMultiTerm) \
_(SimplifyCasts) \
_(SimplifyEliminatesNoOps) \
_(SimplifyMultiVar) \
_(SimplifyEliminatesVar) \
_(StmtClone)

#define TH_FORALL_TESTS_LLVM(_) \
_(LLVMByteImmTest) \
_(LLVMCharImmTest) \
_(LLVMShortImmTest) \
_(LLVMIntImmTest) \
_(LLVMLongImmTest) \
_(LLVMFloatImmTest) \
_(LLVMDoubleImmTest) \
_(LLVMHalfImmTest) \
_(LLVMByteAddTest) \
_(LLVMCharAddTest) \
_(LLVMShortAddTest) \
_(LLVMIntAddTest) \
_(LLVMLongAddTest) \
_(LLVMFloatAddTest) \
_(LLVMDoubleAddTest) \
_(LLVMHalfAddTest) \
_(LLVMByteSubTest) \
_(LLVMCharSubTest) \
_(LLVMShortSubTest) \
_(LLVMIntSubTest) \
_(LLVMLongSubTest) \
_(LLVMFloatSubTest) \
_(LLVMDoubleSubTest) \
_(LLVMHalfSubTest) \
_(LLVMByteMulTest) \
_(LLVMCharMulTest) \
_(LLVMShortMulTest) \
_(LLVMIntMulTest) \
_(LLVMLongMulTest) \
_(LLVMFloatMulTest) \
_(LLVMDoubleMulTest) \
_(LLVMHalfMulTest) \
_(LLVMByteDivTest) \
_(LLVMCharDivTest) \
_(LLVMShortDivTest) \
_(LLVMIntDivTest) \
_(LLVMLongDivTest) \
_(LLVMFloatDivTest) \
_(LLVMDoubleDivTest) \
_(LLVMHalfDivTest) \
_(LLVMIntToFloatCastTest) \
_(LLVMFloatToIntCastTest) \
_(LLVMIntToLongCastTest) \
_(LLVMByteToCharCastTest) \
_(LLVMHalfToLongCastTest) \
_(LLVMByteToDoubleCastTest) \
_(LLVMLetTest01) \
_(LLVMLetTest02) \
_(LLVMLetTestMultitype) \
_(LLVMBufferTest) \
_(LLVMBlockTest) \
_(LLVMLoadStoreTest) \
_(LLVMVecLoadStoreTest) \
_(LLVMVecFloat_acosLane4Test) \
_(LLVMVecFloat_asinLane4Test) \
_(LLVMVecFloat_atanLane4Test) \
_(LLVMVecFloat_coshLane4Test) \
_(LLVMVecFloat_sinhLane4Test) \
_(LLVMVecFloat_tanhLane4Test) \
_(LLVMVecFloat_erfLane4Test) \
_(LLVMVecFloat_erfcLane4Test) \
_(LLVMVecFloat_expm1Lane4Test) \
_(LLVMVecFloat_lgammaLane4Test) \
_(LLVMVecFloat_acosLane8Test) \
_(LLVMVecFloat_asinLane8Test) \
_(LLVMVecFloat_atanLane8Test) \
_(LLVMVecFloat_coshLane8Test) \
_(LLVMVecFloat_sinhLane8Test) \
_(LLVMVecFloat_tanhLane8Test) \
_(LLVMVecFloat_erfLane8Test) \
_(LLVMVecFloat_erfcLane8Test) \
_(LLVMVecFloat_expm1Lane8Test) \
_(LLVMVecFloat_lgammaLane8Test) \
_(LLVMVecDouble_acosLane2Test) \
_(LLVMVecDouble_asinLane2Test) \
_(LLVMVecDouble_atanLane2Test) \
_(LLVMVecDouble_coshLane2Test) \
_(LLVMVecDouble_sinhLane2Test) \
_(LLVMVecDouble_tanhLane2Test) \
_(LLVMVecDouble_erfLane2Test) \
_(LLVMVecDouble_erfcLane2Test) \
_(LLVMVecDouble_expm1Lane2Test) \
_(LLVMVecDouble_lgammaLane2Test) \
_(LLVMVecDouble_acosLane4Test) \
_(LLVMVecDouble_asinLane4Test) \
_(LLVMVecDouble_atanLane4Test) \
_(LLVMVecDouble_coshLane4Test) \
_(LLVMVecDouble_sinhLane4Test) \
_(LLVMVecDouble_tanhLane4Test) \
_(LLVMVecDouble_erfLane4Test) \
_(LLVMVecDouble_erfcLane4Test) \
_(LLVMVecDouble_expm1Lane4Test) \
_(LLVMVecDouble_lgammaLane4Test) \
_(LLVMMemcpyTest) \
_(LLVMBzeroTest) \
_(LLVMElemwiseAdd) \
_(LLVMElemwiseAddFloat) \
_(LLVMElemwiseLog10Float) \
_(LLVMElemwiseMaxInt) \
_(LLVMElemwiseMinInt) \
_(LLVMElemwiseMaxNumFloat) \
_(LLVMElemwiseMaxNumNaNFloat) \
_(LLVMElemwiseMinNumFloat) \
_(LLVMElemwiseMinNumNaNFloat) \
_(LLVMCompareSelectIntEQ) \
_(LLVMCompareSelectFloatEQ) \
_(LLVMStoreFloat) \
_(LLVMSimpleMath01) \
_(LLVMComputeMul) \
_(LLVMBroadcastAdd) \
_(LLVMBitwiseOps) \
_(LLVMDynamicShapeAdd) \
_(LLVMBindDynamicShapeAdd) \
_(LLVMTensorDynamicShapeAdd) \
_(LLVMDynamicShape2D) \
_(LLVMIfThenElseTest) \
#define TH_FORALL_TENSOREXPR_TESTS_LLVM(_) \
_(LLVMByteImmTest) \
_(LLVMCharImmTest) \
_(LLVMShortImmTest) \
_(LLVMIntImmTest) \
_(LLVMLongImmTest) \
_(LLVMFloatImmTest) \
_(LLVMDoubleImmTest) \
_(LLVMHalfImmTest) \
_(LLVMByteAddTest) \
_(LLVMCharAddTest) \
_(LLVMShortAddTest) \
_(LLVMIntAddTest) \
_(LLVMLongAddTest) \
_(LLVMFloatAddTest) \
_(LLVMDoubleAddTest) \
_(LLVMHalfAddTest) \
_(LLVMByteSubTest) \
_(LLVMCharSubTest) \
_(LLVMShortSubTest) \
_(LLVMIntSubTest) \
_(LLVMLongSubTest) \
_(LLVMFloatSubTest) \
_(LLVMDoubleSubTest) \
_(LLVMHalfSubTest) \
_(LLVMByteMulTest) \
_(LLVMCharMulTest) \
_(LLVMShortMulTest) \
_(LLVMIntMulTest) \
_(LLVMLongMulTest) \
_(LLVMFloatMulTest) \
_(LLVMDoubleMulTest) \
_(LLVMHalfMulTest) \
_(LLVMByteDivTest) \
_(LLVMCharDivTest) \
_(LLVMShortDivTest) \
_(LLVMIntDivTest) \
_(LLVMLongDivTest) \
_(LLVMFloatDivTest) \
_(LLVMDoubleDivTest) \
_(LLVMHalfDivTest) \
_(LLVMIntToFloatCastTest) \
_(LLVMFloatToIntCastTest) \
_(LLVMIntToLongCastTest) \
_(LLVMByteToCharCastTest) \
_(LLVMHalfToLongCastTest) \
_(LLVMByteToDoubleCastTest) \
_(LLVMLetTest01) \
_(LLVMLetTest02) \
_(LLVMLetTestMultitype) \
_(LLVMBufferTest) \
_(LLVMBlockTest) \
_(LLVMLoadStoreTest) \
_(LLVMVecLoadStoreTest) \
_(LLVMVecFloat_acosLane4Test) \
_(LLVMVecFloat_asinLane4Test) \
_(LLVMVecFloat_atanLane4Test) \
_(LLVMVecFloat_coshLane4Test) \
_(LLVMVecFloat_sinhLane4Test) \
_(LLVMVecFloat_tanhLane4Test) \
_(LLVMVecFloat_erfLane4Test) \
_(LLVMVecFloat_erfcLane4Test) \
_(LLVMVecFloat_expm1Lane4Test) \
_(LLVMVecFloat_lgammaLane4Test) \
_(LLVMVecFloat_acosLane8Test) \
_(LLVMVecFloat_asinLane8Test) \
_(LLVMVecFloat_atanLane8Test) \
_(LLVMVecFloat_coshLane8Test) \
_(LLVMVecFloat_sinhLane8Test) \
_(LLVMVecFloat_tanhLane8Test) \
_(LLVMVecFloat_erfLane8Test) \
_(LLVMVecFloat_erfcLane8Test) \
_(LLVMVecFloat_expm1Lane8Test) \
_(LLVMVecFloat_lgammaLane8Test) \
_(LLVMVecDouble_acosLane2Test) \
_(LLVMVecDouble_asinLane2Test) \
_(LLVMVecDouble_atanLane2Test) \
_(LLVMVecDouble_coshLane2Test) \
_(LLVMVecDouble_sinhLane2Test) \
_(LLVMVecDouble_tanhLane2Test) \
_(LLVMVecDouble_erfLane2Test) \
_(LLVMVecDouble_erfcLane2Test) \
_(LLVMVecDouble_expm1Lane2Test) \
_(LLVMVecDouble_lgammaLane2Test) \
_(LLVMVecDouble_acosLane4Test) \
_(LLVMVecDouble_asinLane4Test) \
_(LLVMVecDouble_atanLane4Test) \
_(LLVMVecDouble_coshLane4Test) \
_(LLVMVecDouble_sinhLane4Test) \
_(LLVMVecDouble_tanhLane4Test) \
_(LLVMVecDouble_erfLane4Test) \
_(LLVMVecDouble_erfcLane4Test) \
_(LLVMVecDouble_expm1Lane4Test) \
_(LLVMVecDouble_lgammaLane4Test) \
_(LLVMMemcpyTest) \
_(LLVMBzeroTest) \
_(LLVMElemwiseAdd) \
_(LLVMElemwiseAddFloat) \
_(LLVMElemwiseLog10Float) \
_(LLVMElemwiseMaxInt) \
_(LLVMElemwiseMinInt) \
_(LLVMElemwiseMaxNumFloat) \
_(LLVMElemwiseMaxNumNaNFloat) \
_(LLVMElemwiseMinNumFloat) \
_(LLVMElemwiseMinNumNaNFloat) \
_(LLVMCompareSelectIntEQ) \
_(LLVMCompareSelectFloatEQ) \
_(LLVMStoreFloat) \
_(LLVMSimpleMath01) \
_(LLVMComputeMul) \
_(LLVMBroadcastAdd) \
_(LLVMBitwiseOps) \
_(LLVMDynamicShapeAdd) \
_(LLVMBindDynamicShapeAdd) \
_(LLVMTensorDynamicShapeAdd) \
_(LLVMDynamicShape2D) \
_(LLVMIfThenElseTest) \
_(LLVMVectorizerLoadStoreTest)

#define TH_FORALL_TESTS_CUDA(_) \
_(CudaTestVectorAdd01) \
_(CudaTestVectorAdd02) \
_(CudaDynamicShape2D) \
_(CudaTestRand01) \
#define TH_FORALL_TENSOREXPR_TESTS_CUDA(_) \
_(CudaTestVectorAdd01) \
_(CudaTestVectorAdd02) \
_(CudaDynamicShape2D) \
_(CudaTestRand01) \
_(CudaDynamicShapeSplit)

#define DECLARE_TENSOREXPR_TEST(name) void test##name();
TH_FORALL_TESTS(DECLARE_TENSOREXPR_TEST)
TH_FORALL_TENSOREXPR_TESTS(DECLARE_TENSOREXPR_TEST)
#ifdef TORCH_ENABLE_LLVM
TH_FORALL_TESTS_LLVM(DECLARE_TENSOREXPR_TEST)
TH_FORALL_TENSOREXPR_TESTS_LLVM(DECLARE_TENSOREXPR_TEST)
#endif
#ifdef USE_CUDA
TH_FORALL_TESTS_CUDA(DECLARE_TENSOREXPR_TEST)
TH_FORALL_TENSOREXPR_TESTS_CUDA(DECLARE_TENSOREXPR_TEST)
#endif
#undef DECLARE_TENSOREXPR_TEST

@ -3066,6 +3066,22 @@ graph(%Ra, %Rb):
torch._C._jit_run_cpp_tests(run_cuda=True)
tests_setup.shutdown()

@unittest.skipIf(IS_SANDCASTLE, "gtest runs these in sandcastle")
@unittest.skipIf(RUN_CUDA, "covered by test_tensorexpr_cuda")
@unittest.skipIf(IS_WINDOWS, "enable on windows")
@unittest.skipIf(not torch._C._has_tensorexpr_cpp_tests(), "Tests were not built, use BUILD_TEST=1")
@skipIfRocm
def test_tensorexpr_cpp(self):
torch._C._run_tensorexpr_cpp_tests(run_cuda=False)

@unittest.skipIf(IS_SANDCASTLE, "gtest runs these in sandcastle")
@unittest.skipIf(not RUN_CUDA, "covered by test_tensorexpr")
@unittest.skipIf(IS_WINDOWS, "enable on windows")
@unittest.skipIf(not torch._C._has_tensorexpr_cpp_tests(), "Tests were not built, use BUILD_TEST=1")
@skipIfRocm
def test_tensorexpr_cpp_cuda(self):
torch._C._run_tensorexpr_cpp_tests(run_cuda=True)

def test_batchnorm(self):
x = torch.ones(2, 2, 2, 2)
g, outputs, inputs = torch.jit._get_trace_graph(nn.BatchNorm2d(2), x,

@ -401,9 +401,11 @@ def add_torch_libs():
"torch/csrc/utils/tensor_numpy.cpp",
"torch/csrc/utils/tensor_types.cpp",
"test/cpp/jit/torch_python_test.cpp",
"test/cpp/tensorexpr/padded_buffer.cpp",
]

libtorch_python_sources.extend(native.glob(["test/cpp/jit/test_*.cpp"]))
libtorch_python_sources.extend(native.glob(["test/cpp/tensorexpr/test_*.cpp"]))

compiler_flags_cpu = [
"-DUSE_C10D",
@ -433,7 +435,7 @@ def add_torch_libs():
"-Wno-unknown-pragmas",
],
},
"headers": native.glob(["torch/csrc/**/*.h", "torch/csrc/generic/*.cpp", "test/cpp/jit/*.h"]),
"headers": native.glob(["torch/csrc/**/*.h", "torch/csrc/generic/*.cpp", "test/cpp/jit/*.h", "test/cpp/tensorexpr/*.h"]),
}
propagated_pp_flags = [
"-Icaffe2",

@ -118,6 +118,7 @@ if(BUILD_TEST AND NOT USE_ROCM)
add_definitions(-DBUILDING_TESTS)
list(APPEND TORCH_PYTHON_SRCS
${TORCH_ROOT}/test/cpp/jit/torch_python_test.cpp
${TENSOREXPR_TEST_SRCS_WITH_PADDED}
${JIT_TEST_SRCS}
)
endif()

@ -103,6 +103,7 @@ bool loadPythonClasses() {

#if !defined(__HIP_PLATFORM_HCC__)
TORCH_API void runJITCPPTests(bool runCuda);
TORCH_API void runTENSOREXPRCPPTests(bool runCuda);
#endif

void initJITBindings(PyObject* module) {
@ -328,9 +329,23 @@ void initJITBindings(PyObject* module) {
},
py::arg("run_cuda"))
.def("_jit_has_cpp_tests", []() { return true; })
.def(
"_run_tensorexpr_cpp_tests",
[](bool runCuda) {
// We have to release the GIL inside this method, because if we
// happen to initialize the autograd engine in these tests, the
// newly spawned worker threads will try to initialize their
// PyThreadState*, and they need the GIL for this.
pybind11::gil_scoped_release _no_gil;
return runTENSOREXPRCPPTests(runCuda);
},
py::arg("run_cuda"))
.def("_has_tensorexpr_cpp_tests", []() { return true; })
#else
.def("_jit_run_cpp_tests", []() { throw std::exception(); })
.def("_jit_has_cpp_tests", []() { return false; })
.def("_run_tensorexpr_cpp_tests", []() { throw std::exception(); })
.def("_has_tensorexpr_cpp_tests", []() { return false; })
#endif
.def(
"_jit_flatten",

@ -4,6 +4,7 @@
#include <unordered_map>
#include <vector>

#include <c10/macros/Macros.h>
#include <c10/util/Logging.h>
#include <torch/csrc/jit/tensorexpr/buffer.h>
#include <torch/csrc/jit/tensorexpr/codegen.h>
@ -98,6 +99,26 @@ inline bool mod_value(bool lhs, bool rhs) {
throw std::runtime_error("Attempted modulus of bool");
}

template <typename T>
inline typename std::enable_if<std::is_integral<T>::value, T>::type div_value(
T lhs,
T rhs) {
TORCH_CHECK(rhs != 0, "Division by zero");
return lhs / rhs;
}

template <typename T>
inline typename std::enable_if<std::is_floating_point<T>::value, T>::
type __ubsan_ignore_float_divide_by_zero__
div_value(T lhs, T rhs) {
return lhs / rhs;
}

inline bool div_value(bool lhs, bool rhs) {
LOG(FATAL) << "Attempted division of bool";
return false;
}

class SimpleIREvaluator : public CodeGen, public IRVisitor {
public:
using CodeGen::CodeGen;
@ -205,7 +226,7 @@ class SimpleIREvaluator : public CodeGen, public IRVisitor {
result_v[i] = lhs_v[i] * rhs_v[i];
break;
case IRNodeType::kDiv:
result_v[i] = lhs_v[i] / rhs_v[i];
result_v[i] = div_value(lhs_v[i], rhs_v[i]);
break;
case IRNodeType::kMod:
result_v[i] = mod_value(lhs_v[i], rhs_v[i]);
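For reference, a minimal standalone sketch (not part of the commit) of how the SFINAE-based div_value overloads added to the evaluator dispatch between integral and floating-point operands; TORCH_CHECK is replaced here with a plain std::runtime_error so the snippet builds without any torch headers:

#include <iostream>
#include <stdexcept>
#include <type_traits>

// Integral overload: guard against undefined behaviour on division by zero.
template <typename T>
typename std::enable_if<std::is_integral<T>::value, T>::type div_value(T lhs, T rhs) {
  if (rhs == 0) {
    throw std::runtime_error("Division by zero");
  }
  return lhs / rhs;
}

// Floating-point overload: IEEE-754 division by zero yields inf/nan, so no guard is needed.
template <typename T>
typename std::enable_if<std::is_floating_point<T>::value, T>::type div_value(T lhs, T rhs) {
  return lhs / rhs;
}

int main() {
  std::cout << div_value(6, 3) << "\n";     // 2   (integral overload)
  std::cout << div_value(1.f, 0.f) << "\n"; // inf (floating-point overload)
  try {
    div_value(1, 0);
  } catch (const std::runtime_error& e) {
    std::cout << e.what() << "\n";          // Division by zero
  }
  return 0;
}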