[jit] Split DynamicType conformance test into smaller pieces. (#71275)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71275

Currently it's taking more than 10 minutes to run the conformance test. Instead we should use parametrized test to shard into test segments so that they can run in parallel.
ghstack-source-id: 146990608

Test Plan:
```
[zhxchen17@devbig560.ftw3 /data/users/zhxchen17/fbsource/fbcode] buck test mode/dev-tsan //caffe2/test/cpp/jit:jit -- -r 'LiteInterpreterDynamicTypeTestFixture'
Building... 34.9 sec (99%) 12110/12111 jobs, 0/12111 updated
Tpx test run coordinator for Facebook. See https://fburl.com/tpx for details.
Running with tpx session id: ebea52b3-7c7f-46be-9f69-18e2e7b040cc
Trace available for this run at /tmp/tpx-20220113-113635.717778/trace.log
RemoteExecution session id: reSessionID-ebea52b3-7c7f-46be-9f69-18e2e7b040cc-tpx
Started reporting to test run: https://www.internalfb.com/intern/testinfra/testrun/4222124735827748
    ✓ ListingSuccess: caffe2/test/cpp/jit:jit : 431 tests discovered (11.173)
    ✓ Pass: caffe2/test/cpp/jit:jit - Conformance/LiteInterpreterDynamicTypeTestFixture.Conformance/0 (51.331)
    ✓ Pass: caffe2/test/cpp/jit:jit - Conformance/LiteInterpreterDynamicTypeTestFixture.Conformance/1 (65.614)
    ✓ Pass: caffe2/test/cpp/jit:jit - Conformance/LiteInterpreterDynamicTypeTestFixture.Conformance/3 (76.875)
    ✓ Pass: caffe2/test/cpp/jit:jit - Conformance/LiteInterpreterDynamicTypeTestFixture.Conformance/5 (77.271)
    ✓ Pass: caffe2/test/cpp/jit:jit - Conformance/LiteInterpreterDynamicTypeTestFixture.Conformance/4 (78.871)
    ✓ Pass: caffe2/test/cpp/jit:jit - Conformance/LiteInterpreterDynamicTypeTestFixture.Conformance/6 (78.984)
    ✓ Pass: caffe2/test/cpp/jit:jit - Conformance/LiteInterpreterDynamicTypeTestFixture.Conformance/7 (84.068)
    ✓ Pass: caffe2/test/cpp/jit:jit - Conformance/LiteInterpreterDynamicTypeTestFixture.Conformance/2 (85.198)
    ✓ Pass: caffe2/test/cpp/jit:jit - Conformance/LiteInterpreterDynamicTypeTestFixture.Conformance/8 (88.815)
    ✓ Pass: caffe2/test/cpp/jit:jit - Conformance/LiteInterpreterDynamicTypeTestFixture.Conformance/9 (90.332)
Summary
  Pass: 10
  ListingSuccess: 1
If you need help understanding your runs, please follow the wiki: https://fburl.com/posting_in_tpx_users
Finished test run: https://www.internalfb.com/intern/testinfra/testrun/4222124735827748
```

Reviewed By: qihqi

Differential Revision: D33570442

fbshipit-source-id: 5c49e03b0f88068d444c84b4adeaaf45433ce1fa
This commit is contained in:
Zhengxu Chen
2022-01-13 18:21:14 -08:00
committed by Facebook GitHub Bot
parent 81f693d509
commit 5f2b4be3b9

View File

@ -2055,6 +2055,67 @@ void enumerateTupleType(
}
}
class LiteInterpreterDynamicTypeTestFixture
: public ::testing::TestWithParam<size_t> {
protected:
void SetUp() override {
cu = std::make_shared<CompilationUnit>();
std::vector<TypePtr> keyTypes = {
AnyType::get(),
IntType::get(),
BoolType::get(),
FloatType::get(),
ComplexType::get(),
StringType::get(),
TensorType::get(),
DeviceObjType::get(),
};
types = {
NoneType::get(),
NumberType::get(),
ClassType::create("__torch__.TestClass1", cu),
ClassType::create("__torch__.TestClass2", cu),
AnyListType::get(),
AnyTupleType::get(),
StreamObjType::get(),
CapsuleType::get(),
GeneratorType::get(),
StorageType::get(),
VarType::create("t"),
VarType::create("v"),
AnyClassType::get()};
std::copy(keyTypes.begin(), keyTypes.end(), back_inserter(types));
auto expandTypes = [&](size_t tupleSize) {
std::vector<TypePtr> nested;
for (const auto& type : types) {
if (!(type == AnyType::get())) {
nested.emplace_back(ListType::create(type));
if (!(type == NoneType::get() ||
type->kind() == OptionalType::Kind)) {
nested.emplace_back(OptionalType::create(type));
}
}
for (const auto& keyType : keyTypes) {
nested.emplace_back(DictType::create(keyType, type));
}
}
std::vector<TypePtr> tmp;
enumerateTupleType(tupleSize, tmp, types, nested);
std::move(
std::begin(nested), std::end(nested), std::back_inserter(types));
};
expandTypes(1);
expandTypes(1);
}
std::shared_ptr<CompilationUnit> cu;
std::vector<TypePtr> types;
public:
static constexpr size_t kNumSplits = 10;
};
constexpr size_t LiteInterpreterDynamicTypeTestFixture::kNumSplits;
/**
* Enumerate all possible JIT types appearing in mobile runtime, and test
* whether subtyping relation is preserved after one of the JIT types is
@ -2065,55 +2126,14 @@ void enumerateTupleType(
* of types. We call expandTypes() twice to test types nested less or equal
* to two levels. e.g. List[Optional[Tensor]], Optional[Dict[Int, Bool]], etc.
*/
TEST(LiteInterpreterTest, DynamicType) {
auto cu = std::make_shared<CompilationUnit>();
std::vector<TypePtr> keyTypes = {
AnyType::get(),
IntType::get(),
BoolType::get(),
FloatType::get(),
ComplexType::get(),
StringType::get(),
TensorType::get(),
DeviceObjType::get(),
};
std::vector<TypePtr> types = {
NoneType::get(),
NumberType::get(),
ClassType::create("__torch__.TestClass1", cu),
ClassType::create("__torch__.TestClass2", cu),
AnyListType::get(),
AnyTupleType::get(),
StreamObjType::get(),
CapsuleType::get(),
GeneratorType::get(),
StorageType::get(),
VarType::create("t"),
VarType::create("v"),
AnyClassType::get()};
std::copy(keyTypes.begin(), keyTypes.end(), back_inserter(types));
auto expandTypes = [&](size_t tupleSize) {
std::vector<TypePtr> nested;
for (const auto& type : types) {
if (!(type == AnyType::get())) {
nested.push_back(ListType::create(type));
if (!(type == NoneType::get() || type->kind() == OptionalType::Kind)) {
nested.push_back(OptionalType::create(type));
}
}
for (const auto& keyType : keyTypes) {
nested.push_back(DictType::create(keyType, type));
}
}
std::vector<TypePtr> tmp;
enumerateTupleType(tupleSize, tmp, types, nested);
std::move(std::begin(nested), std::end(nested), std::back_inserter(types));
};
expandTypes(1);
expandTypes(1);
TEST_P(LiteInterpreterDynamicTypeTestFixture, Conformance) {
size_t num = types.size() / LiteInterpreterDynamicTypeTestFixture::kNumSplits;
size_t begin = num * GetParam();
size_t end = std::min(types.size(), begin + num);
for (const auto& a : types) {
auto da = DynamicType::create(*a);
for (const auto& b : types) {
for (size_t i = begin; i < end; i++) {
const auto& b = types[i];
bool result = a->isSubtypeOf(*b);
EXPECT_EQ(result, da->isSubtypeOf(*b));
result = b->isSubtypeOf(*a);
@ -2122,5 +2142,12 @@ TEST(LiteInterpreterTest, DynamicType) {
}
}
INSTANTIATE_TEST_CASE_P(
PyTorch,
LiteInterpreterDynamicTypeTestFixture,
::testing::Range(
static_cast<size_t>(0),
LiteInterpreterDynamicTypeTestFixture::kNumSplits));
} // namespace jit
} // namespace torch