Files
pytorch/test/cpp/tensorexpr/test_utils.h
PyTorch MergeBot 13398dab79 Revert "Remove tensorexpr tests (#158928)"
This reverts commit a3f9f79f591102afa93145bb67dc7e34df44f9a4.

Reverted https://github.com/pytorch/pytorch/pull/158928 on behalf of https://github.com/clee2000 due to Theres still some references to the things removed in this PR in test.sh, the jobs on this PR are failing because of that but log classifier is probably pointing to a wrong line, should be an easy fix tho ([comment](https://github.com/pytorch/pytorch/pull/158928#issuecomment-3114873706))
2025-07-24 20:45:30 +00:00

79 lines
2.6 KiB
C++

#pragma once
#include <memory>
#include <vector>
#include <test/cpp/tensorexpr/test_base.h>
#include <torch/csrc/jit/tensorexpr/fwd_decls.h>
#include <torch/csrc/jit/testing/file_check.h>
namespace torch {
namespace jit {
using namespace torch::jit::tensorexpr;
// Asserts that `node` is of IR node type `T`: `to<T>(node)` must not be null.
// The expansion is a braced block, so the temporary `node_` stays local to the
// macro and nothing is introduced into the caller's scope.
// NOTE(review): `to<T>` and ASSERT_NE are declared outside this header;
// presumably `to<T>` yields nullptr on a type mismatch and ASSERT_NE is
// GoogleTest's fatal assertion -- confirm.
#define IS_NODE(T, node) \
{ \
auto node_ = to<T>(node); \
ASSERT_NE(nullptr, node_); \
}
// Declares `name` in the CALLER's scope as the result of `to<T>(node)` and
// asserts that it is not null, so following statements can keep inspecting
// the converted node through `name`.
// The expansion is intentionally unbraced and consists of two statements:
// `name` must not already exist in the enclosing scope, and the macro must not
// be used as the sole body of an unbraced `if`/`for`/`while`.
#define IS_NODE_WITH_NAME(T, node, name) \
auto name = to<T>(node); \
ASSERT_NE(nullptr, name);
// Expects `node` to be a `Cast` wrapping a `T`, and exposes the wrapped node.
//   T    - expected type of the Cast's source value.
//   node - expression to inspect.
//   name - identifier declared in the CALLER's scope as `NodePtr<T>`; it is
//          set to `to<T>(cast->src_value())`.
//   Type - bare enumerator name, pasted after `ScalarType::`, that the Cast's
//          `dtype().scalar_type()` must equal.
// Asserts, in order: `node` converts to `Cast`, the scalar type matches, and
// the converted source value (`name`) is not null. The intermediate Cast
// pointer `node_` is confined to the inner block.
// Expands to several statements; do not use as an unbraced `if`/loop body.
#define IS_NODE_WITH_NAME_AND_CAST(T, node, name, Type) \
NodePtr<T> name = nullptr; \
{ \
auto node_ = to<Cast>(node); \
ASSERT_NE(nullptr, node_); \
ASSERT_EQ(node_->dtype().scalar_type(), ScalarType::Type); \
name = to<T>(node_->src_value()); \
} \
ASSERT_NE(nullptr, name);
// Asserts that `node` is an immediate of type `T##Imm` (the token `Imm` is
// pasted onto `T`, e.g. T = Int checks for `IntImm`) and that its `value()`
// equals `val`.
// `val` is evaluated inside the macro's block, where the local `node_` is in
// scope; avoid passing an expression that refers to a variable named `node_`.
#define IS_IMM_WITH_VAL(T, node, val) \
{ \
auto node_ = to<T##Imm>(node); \
ASSERT_NE(nullptr, node_); \
ASSERT_EQ(node_->value(), val); \
}
// Asserts that `node` is a `Var` and that its `name_hint()` equals `name`.
// Here `name` is a value to compare against (not an identifier to declare);
// nothing is introduced into the caller's scope.
#define IS_VAR_WITH_NAME(node, name) \
{ \
auto node_ = to<Var>(node); \
ASSERT_NE(nullptr, node_); \
ASSERT_EQ(node_->name_hint(), name); \
}
// Expects `node` to be a binary operation of type `T` whose two operands are
// variables.
//   T      - expected node type; must provide `lhs()` and `rhs()`.
//   node   - expression to inspect.
//   name   - identifier declared in the CALLER's scope as `NodePtr<T>`, set to
//            `to<T>(node)` so the caller can keep using it afterwards.
//   v1, v2 - expected `name_hint()` of the left and right operand.
// Asserts that the conversion is non-null, then checks each operand with
// IS_VAR_WITH_NAME.
// Expands to a declaration plus a block; do not use as an unbraced `if`/loop
// body.
#define IS_BINOP_W_VARS(T, node, name, v1, v2) \
NodePtr<T> name = nullptr; \
{ \
name = to<T>(node); \
ASSERT_NE(nullptr, name); \
IS_VAR_WITH_NAME(name->lhs(), v1); \
IS_VAR_WITH_NAME(name->rhs(), v2); \
}
// Expects `node` to be a binary operation of type `T` with a variable on the
// left and an integer immediate on the right.
//   T    - expected node type; must provide `lhs()` and `rhs()`.
//   node - expression to inspect.
//   name - identifier declared in the CALLER's scope as `NodePtr<T>`, set to
//          `to<T>(node)`.
//   v    - expected `name_hint()` of the left-hand `Var`.
//   c    - expected value of the right-hand operand, which must be an
//          `IntImm` (the immediate type is fixed to Int here).
// Expands to a declaration plus a block; do not use as an unbraced `if`/loop
// body.
#define IS_BINOP_W_CONST(T, node, name, v, c) \
NodePtr<T> name = nullptr; \
{ \
name = to<T>(node); \
ASSERT_NE(nullptr, name); \
IS_VAR_WITH_NAME(name->lhs(), v); \
IS_IMM_WITH_VAL(Int, name->rhs(), c); \
}
// Asserts that `node` is an `Intrinsics` node whose `op_type()` is `kRand`.
// Fully braced; nothing is introduced into the caller's scope.
#define IS_RAND(node) \
{ \
auto node_ = to<Intrinsics>(node); \
ASSERT_NE(nullptr, node_); \
ASSERT_EQ(node_->op_type(), kRand); \
}
// IR-checking helpers; only declared here, the definitions live elsewhere.
// NOTE(review): presumably each one renders the given statement/expression to
// text and matches it against `pattern` using FileCheck (file_check.h is
// included by this header) -- confirm against the definitions.

// Checks the statement `s` against `pattern`.
void checkIR(StmtPtr s, const std::string& pattern);
// Checks the expression node `e` against `pattern`.
void checkExprIR(ExprPtr e, const std::string& pattern);
// Overload of checkExprIR taking an `ExprHandle` instead of an `ExprPtr`.
void checkExprIR(const ExprHandle& e, const std::string& pattern);
} // namespace jit
} // namespace torch