Files
pytorch/test/cpp/jit/test_union.cpp
Scott Wolchok 2d885ab73d [jit] Reduce refcounting of Types (#65345)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65345

FooType::get() can return a const reference. Inconveniently, converting shared_ptr<FooType> to shared_ptr<Type> requires a copy & refcount bump, so to properly take advantage of this in unshapedType() we need to take a const Type& in isSubtypeOf(), which is good practice anyway -- don't require a shared_ptr if you don't need to take ownership.
ghstack-source-id: 140044165

Test Plan:
CI

perf says c10::unshapedType time decreased from 2.8% to 2.2% during static runtime startup, though I expect this to be generally beneficial.

Reviewed By: hlu1

Differential Revision: D31027361

fbshipit-source-id: 676feb81db9f74ad7b8651d8774f4ecb4cfa6ab8
2021-10-08 09:03:04 -07:00

150 lines
4.7 KiB
C++

#include <gtest/gtest.h>
#include <ATen/core/jit_type.h>
#include <test/cpp/jit/test_utils.h>
#include <torch/csrc/jit/ir/ir.h>
namespace torch {
namespace jit {
class UnionTypeTest : public ::testing::Test {
public:
// None
const TypePtr none = NoneType::get();
// List[str]
const TypePtr l1 = ListType::ofStrings();
// Optional[int]
const TypePtr opt1 = OptionalType::create(IntType::get());
// Optional[float]
const TypePtr opt2 = OptionalType::create(FloatType::get());
// Optional[List[str]]
const TypePtr opt3 = OptionalType::create(ListType::ofStrings());
// Tuple[Optional[int], int]
const TypePtr tup1 =
TupleType::create({OptionalType::create(IntType::get()), IntType::get()});
// Tuple[int, int]
const TypePtr tup2 = TupleType::create({IntType::get(), IntType::get()});
bool hasType(UnionTypePtr u, TypePtr t) {
auto res = std::find(u->getTypes().begin(), u->getTypes().end(), t);
return res != u->getTypes().end();
}
};
// Two unions created from equivalent member lists must compare equal with
// operator==, even when the member TypePtrs were obtained separately.
TEST_F(UnionTypeTest, UnionOperatorEquals) {
  // Left-hand side: built from the TypePtrs held by the fixture.
  const UnionTypePtr fromFixture =
      UnionType::create({l1, tup2, StringType::get()});

  // Right-hand side: same member types, but requested afresh.
  const TypePtr freshList = ListType::ofStrings();
  const TypePtr freshTuple =
      TupleType::create({IntType::get(), IntType::get()});
  const UnionTypePtr fromFresh =
      UnionType::create({freshList, freshTuple, StringType::get()});

  ASSERT_TRUE(*fromFixture == *fromFresh);
}
// Creating a union from Optional[int] and Optional[float] is expected to
// produce exactly three members: int, float and None.
TEST_F(UnionTypeTest, UnionCreate_OptionalT1AndOptionalT2) {
  // Expected result: Union[int, float, None]
  const UnionTypePtr merged = UnionType::create({opt1, opt2});

  ASSERT_EQ(merged->getTypes().size(), 3);
  ASSERT_TRUE(hasType(merged, IntType::get()));
  ASSERT_TRUE(hasType(merged, FloatType::get()));
  ASSERT_TRUE(hasType(merged, NoneType::get()));
}
// Creating a union from Optional[int] and int is expected to produce exactly
// two members: int and None.
TEST_F(UnionTypeTest, UnionCreate_OptionalTAndT) {
  // Expected result: Union[int, None]
  const UnionTypePtr merged = UnionType::create({opt1, IntType::get()});

  ASSERT_EQ(merged->getTypes().size(), 2);
  ASSERT_TRUE(hasType(merged, IntType::get()));
  ASSERT_TRUE(hasType(merged, NoneType::get()));
}
// Creating a union from str, Tuple[Optional[int], int] and Tuple[int, int] is
// expected to leave only two members: str and the Tuple[Optional[int], int]
// instance held by the fixture.
TEST_F(UnionTypeTest, UnionCreate_TupleWithSubtypingRelationship) {
  // Expected result: Union[Tuple[Optional[int], int], str]
  const UnionTypePtr merged =
      UnionType::create({StringType::get(), tup1, tup2});

  ASSERT_EQ(merged->getTypes().size(), 2);
  ASSERT_TRUE(hasType(merged, StringType::get()));
  ASSERT_TRUE(hasType(merged, tup1));
}
// Creating a union from List[str] and str is expected to keep both as
// separate members.
TEST_F(UnionTypeTest, UnionCreate_ContainerTAndT) {
  // Expected result: Union[List[str], str]
  const UnionTypePtr merged = UnionType::create({l1, StringType::get()});

  ASSERT_EQ(merged->getTypes().size(), 2);
  ASSERT_TRUE(hasType(merged, StringType::get()));
  ASSERT_TRUE(hasType(merged, ListType::ofStrings()));
}
// Creating a union from List[str], Optional[List[str]] and str is expected to
// produce exactly three members: List[str], None and str.
TEST_F(UnionTypeTest, UnionCreate_OptionalContainerTAndContainerTAndT) {
  // Goal: Union[List[str], None, str]
  const UnionTypePtr u = UnionType::create({l1, opt3, StringType::get()});

  ASSERT_EQ(u->getTypes().size(), 3);
  ASSERT_TRUE(UnionTypeTest::hasType(u, StringType::get()));
  ASSERT_TRUE(UnionTypeTest::hasType(u, ListType::ofStrings()));
  // The goal above names None as the third member; check it explicitly so the
  // size assertion cannot be satisfied by some other unexpected type.
  ASSERT_TRUE(UnionTypeTest::hasType(u, NoneType::get()));
}
// Checks how NumberType relates to unions of the numeric types: it must be
// mutually a subtype of, and equal to, Union[int, float, Complex], while
// Union[int, float, Complex, None] must be a strict supertype of it.
TEST_F(UnionTypeTest, Subtyping_NumberType) {
  // Union[int, float, Complex]
  const UnionTypePtr numericUnion =
      UnionType::create({IntType::get(), FloatType::get(), ComplexType::get()});

  // Union[int, float, Complex, None]
  const UnionTypePtr numericOrNone = UnionType::create(
      {IntType::get(), FloatType::get(), ComplexType::get(), NoneType::get()});

  const NumberTypePtr number = NumberType::get();

  // Without None: subtype in both directions, and equal.
  ASSERT_TRUE(number->isSubtypeOf(*numericUnion));
  ASSERT_TRUE(numericUnion->isSubtypeOf(*number));
  ASSERT_TRUE(*number == *numericUnion);

  // With None: Number fits into the union, but not the other way round.
  ASSERT_TRUE(number->isSubtypeOf(*numericOrNone));
  ASSERT_FALSE(numericOrNone->isSubtypeOf(*number));
  ASSERT_FALSE(*number == *numericOrNone);
}
// Checks subtyping between None, Optional[int] and several unions.
//
// Every isSubtypeOf() argument is passed as a dereferenced `const Type&`,
// matching the call style used in Subtyping_NumberType: the callee only
// inspects the type, so no shared pointer needs to be handed over.
TEST_F(UnionTypeTest, Subtyping_OptionalType) {
  // Union[int, None]
  const UnionTypePtr union1 =
      UnionType::create({IntType::get(), NoneType::get()});

  // Union[int, str, None]
  const UnionTypePtr union2 =
      UnionType::create({IntType::get(), StringType::get(), NoneType::get()});

  // Union[int, str, List[str]]
  const UnionTypePtr union3 = UnionType::create(
      {IntType::get(), StringType::get(), ListType::ofStrings()});

  // None on the left-hand side.
  ASSERT_TRUE(none->isSubtypeOf(*opt1));
  ASSERT_TRUE(none->isSubtypeOf(*union1));
  ASSERT_TRUE(none->isSubtypeOf(*union2));
  ASSERT_FALSE(none->isSubtypeOf(*union3));

  // Optional[int] on the left-hand side.
  ASSERT_FALSE(opt1->isSubtypeOf(*none));
  ASSERT_TRUE(opt1->isSubtypeOf(*union1));
  ASSERT_TRUE(opt1->isSubtypeOf(*union2));
  ASSERT_FALSE(opt1->isSubtypeOf(*union3));

  // Union[int, None] on the left-hand side.
  ASSERT_FALSE(union1->isSubtypeOf(*none));
  ASSERT_TRUE(union1->isSubtypeOf(*opt1));
  ASSERT_TRUE(union1->isSubtypeOf(*union2));
  ASSERT_FALSE(union1->isSubtypeOf(*union3));

  // A wider union is not a subtype of a narrower one.
  ASSERT_FALSE(union2->isSubtypeOf(*union1));
}
} // namespace jit
} // namespace torch