mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
Summary: This PR replaces the https://github.com/pytorch/pytorch/pull/53180 PR stack, which contains all the review discussions. A replacement was needed because of a messy Sandcastle issue. Pull Request resolved: https://github.com/pytorch/pytorch/pull/64234 Reviewed By: gmagogsfm Differential Revision: D30656444 Pulled By: ansley fbshipit-source-id: 77536c8bcc88162e2c72636026ca3c16891d669a
150 lines
4.7 KiB
C++
150 lines
4.7 KiB
C++
#include <gtest/gtest.h>
|
|
|
|
#include <ATen/core/jit_type.h>
|
|
#include <test/cpp/jit/test_utils.h>
|
|
#include <torch/csrc/jit/ir/ir.h>
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
|
|
class UnionTypeTest : public ::testing::Test {
|
|
public:
|
|
// None
|
|
const TypePtr none = NoneType::get();
|
|
|
|
// List[str]
|
|
const TypePtr l1 = ListType::ofStrings();
|
|
|
|
// Optional[int]
|
|
const TypePtr opt1 = OptionalType::create(IntType::get());
|
|
|
|
// Optional[float]
|
|
const TypePtr opt2 = OptionalType::create(FloatType::get());
|
|
|
|
// Optional[List[str]]
|
|
const TypePtr opt3 = OptionalType::create(ListType::ofStrings());
|
|
|
|
// Tuple[Optional[int], int]
|
|
const TypePtr tup1 =
|
|
TupleType::create({OptionalType::create(IntType::get()), IntType::get()});
|
|
|
|
// Tuple[int, int]
|
|
const TypePtr tup2 = TupleType::create({IntType::get(), IntType::get()});
|
|
|
|
bool hasType(UnionTypePtr u, TypePtr t) {
|
|
auto res = std::find(u->getTypes().begin(), u->getTypes().end(), t);
|
|
return res != u->getTypes().end();
|
|
}
|
|
};
|
|
|
|
// Checks operator== on UnionType: a union built from the fixture's members
// must compare equal to one built from separately obtained TypePtrs that
// describe the same types.
TEST_F(UnionTypeTest, UnionOperatorEquals) {
  // Left-hand side: List[str], Tuple[int, int] and str from the fixture.
  const UnionTypePtr lhs = UnionType::create({l1, tup2, StringType::get()});

  // Right-hand side: the same member types, requested anew.
  const TypePtr strList = ListType::ofStrings();
  const TypePtr intPair = TupleType::create({IntType::get(), IntType::get()});
  const UnionTypePtr rhs =
      UnionType::create({strList, intPair, StringType::get()});

  ASSERT_TRUE(*lhs == *rhs);
}
|
|
|
|
// Creating a union from Optional[int] and Optional[float] is expected to
// produce Union[int, float, None].
TEST_F(UnionTypeTest, UnionCreate_OptionalT1AndOptionalT2) {
  const UnionTypePtr merged = UnionType::create({opt1, opt2});

  // Exactly three member types: int, float and None.
  const auto memberCount = merged->getTypes().size();
  ASSERT_EQ(memberCount, 3);

  ASSERT_TRUE(hasType(merged, IntType::get()));
  ASSERT_TRUE(hasType(merged, FloatType::get()));
  ASSERT_TRUE(hasType(merged, NoneType::get()));
}
|
|
|
|
// Creating a union from Optional[int] and int is expected to produce
// Union[int, None]: the plain int must not show up as an extra member.
TEST_F(UnionTypeTest, UnionCreate_OptionalTAndT) {
  const UnionTypePtr merged = UnionType::create({opt1, IntType::get()});

  // Exactly two member types: int and None.
  const auto memberCount = merged->getTypes().size();
  ASSERT_EQ(memberCount, 2);

  ASSERT_TRUE(hasType(merged, IntType::get()));
  ASSERT_TRUE(hasType(merged, NoneType::get()));
}
|
|
|
|
// Creating a union from str, Tuple[Optional[int], int] and Tuple[int, int] is
// expected to produce Union[Tuple[Optional[int], int], str]: only one of the
// two tuple types remains in the result.
TEST_F(UnionTypeTest, UnionCreate_TupleWithSubtypingRelationship) {
  const UnionTypePtr merged =
      UnionType::create({StringType::get(), tup1, tup2});

  // Exactly two member types: str and the Tuple[Optional[int], int] handle.
  const auto memberCount = merged->getTypes().size();
  ASSERT_EQ(memberCount, 2);

  ASSERT_TRUE(hasType(merged, StringType::get()));
  ASSERT_TRUE(hasType(merged, tup1));
}
|
|
|
|
// Creating a union from List[str] and str is expected to keep both members,
// i.e. produce Union[List[str], str].
TEST_F(UnionTypeTest, UnionCreate_ContainerTAndT) {
  const UnionTypePtr merged = UnionType::create({l1, StringType::get()});

  // Exactly two member types: str and List[str].
  const auto memberCount = merged->getTypes().size();
  ASSERT_EQ(memberCount, 2);

  ASSERT_TRUE(hasType(merged, StringType::get()));
  ASSERT_TRUE(hasType(merged, ListType::ofStrings()));
}
|
|
|
|
// Creating a union from List[str], Optional[List[str]] and str is expected to
// produce Union[List[str], None, str].
TEST_F(UnionTypeTest, UnionCreate_OptionalContainerTAndContainerTAndT) {
  // Goal: Union[List[str], None, str]
  const UnionTypePtr u = UnionType::create({l1, opt3, StringType::get()});

  ASSERT_EQ(u->getTypes().size(), 3);
  ASSERT_TRUE(UnionTypeTest::hasType(u, StringType::get()));
  ASSERT_TRUE(UnionTypeTest::hasType(u, ListType::ofStrings()));
  // The goal lists None as the third member; check it explicitly so the test
  // does not rely on the size assertion alone to cover it.
  ASSERT_TRUE(UnionTypeTest::hasType(u, NoneType::get()));
}
|
|
|
|
// Checks the subtyping and equality relationship between NumberType and
// unions of the scalar types int, float and Complex.
TEST_F(UnionTypeTest, Subtyping_NumberType) {
  const NumberTypePtr number = NumberType::get();

  // Union[int, float, Complex]
  const UnionTypePtr scalars =
      UnionType::create({IntType::get(), FloatType::get(), ComplexType::get()});

  // Union[int, float, Complex, None]
  const UnionTypePtr scalarsOrNone = UnionType::create(
      {IntType::get(), FloatType::get(), ComplexType::get(), NoneType::get()});

  // Number and Union[int, float, Complex] are subtypes of each other and
  // compare equal.
  ASSERT_TRUE(number->isSubtypeOf(scalars));
  ASSERT_TRUE(scalars->isSubtypeOf(number));
  ASSERT_TRUE(*number == *scalars);

  // With None added, the relationship only holds in one direction and the
  // two types no longer compare equal.
  ASSERT_TRUE(number->isSubtypeOf(scalarsOrNone));
  ASSERT_FALSE(scalarsOrNone->isSubtypeOf(number));
  ASSERT_FALSE(*number == *scalarsOrNone);
}
|
|
|
|
// Checks subtyping between None, Optional[int] and several unions, in both
// directions.
TEST_F(UnionTypeTest, Subtyping_OptionalType) {
  // Union[int, None]
  const UnionTypePtr intOrNone =
      UnionType::create({IntType::get(), NoneType::get()});

  // Union[int, str, None]
  const UnionTypePtr intStrOrNone =
      UnionType::create({IntType::get(), StringType::get(), NoneType::get()});

  // Union[int, str, List[str]]
  const UnionTypePtr intStrOrList = UnionType::create(
      {IntType::get(), StringType::get(), ListType::ofStrings()});

  // None as the candidate subtype.
  ASSERT_TRUE(none->isSubtypeOf(opt1));
  ASSERT_TRUE(none->isSubtypeOf(intOrNone));
  ASSERT_TRUE(none->isSubtypeOf(intStrOrNone));
  ASSERT_FALSE(none->isSubtypeOf(intStrOrList));

  // Optional[int] as the candidate subtype.
  ASSERT_FALSE(opt1->isSubtypeOf(none));
  ASSERT_TRUE(opt1->isSubtypeOf(intOrNone));
  ASSERT_TRUE(opt1->isSubtypeOf(intStrOrNone));
  ASSERT_FALSE(opt1->isSubtypeOf(intStrOrList));

  // Union[int, None] as the candidate subtype.
  ASSERT_FALSE(intOrNone->isSubtypeOf(none));
  ASSERT_TRUE(intOrNone->isSubtypeOf(opt1));
  ASSERT_TRUE(intOrNone->isSubtypeOf(intStrOrNone));
  ASSERT_FALSE(intOrNone->isSubtypeOf(intStrOrList));

  // The wider union is not a subtype of the narrower one.
  ASSERT_FALSE(intStrOrNone->isSubtypeOf(intOrNone));
}
|
|
|
|
} // namespace jit
|
|
} // namespace torch
|