mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
[jit] Split DynamicType conformance test into smaller pieces. (#71275)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/71275 Currently it's taking more than 10 minutes to run the conformance test. Instead we should use parametrized test to shard into test segments so that they can run in parallel. ghstack-source-id: 146990608 Test Plan: ``` [zhxchen17@devbig560.ftw3 /data/users/zhxchen17/fbsource/fbcode] buck test mode/dev-tsan //caffe2/test/cpp/jit:jit -- -r 'LiteInterpreterDynamicTypeTestFixture' Building... 34.9 sec (99%) 12110/12111 jobs, 0/12111 updated Tpx test run coordinator for Facebook. See https://fburl.com/tpx for details. Running with tpx session id: ebea52b3-7c7f-46be-9f69-18e2e7b040cc Trace available for this run at /tmp/tpx-20220113-113635.717778/trace.log RemoteExecution session id: reSessionID-ebea52b3-7c7f-46be-9f69-18e2e7b040cc-tpx Started reporting to test run: https://www.internalfb.com/intern/testinfra/testrun/4222124735827748 ✓ ListingSuccess: caffe2/test/cpp/jit:jit : 431 tests discovered (11.173) ✓ Pass: caffe2/test/cpp/jit:jit - Conformance/LiteInterpreterDynamicTypeTestFixture.Conformance/0 (51.331) ✓ Pass: caffe2/test/cpp/jit:jit - Conformance/LiteInterpreterDynamicTypeTestFixture.Conformance/1 (65.614) ✓ Pass: caffe2/test/cpp/jit:jit - Conformance/LiteInterpreterDynamicTypeTestFixture.Conformance/3 (76.875) ✓ Pass: caffe2/test/cpp/jit:jit - Conformance/LiteInterpreterDynamicTypeTestFixture.Conformance/5 (77.271) ✓ Pass: caffe2/test/cpp/jit:jit - Conformance/LiteInterpreterDynamicTypeTestFixture.Conformance/4 (78.871) ✓ Pass: caffe2/test/cpp/jit:jit - Conformance/LiteInterpreterDynamicTypeTestFixture.Conformance/6 (78.984) ✓ Pass: caffe2/test/cpp/jit:jit - Conformance/LiteInterpreterDynamicTypeTestFixture.Conformance/7 (84.068) ✓ Pass: caffe2/test/cpp/jit:jit - Conformance/LiteInterpreterDynamicTypeTestFixture.Conformance/2 (85.198) ✓ Pass: caffe2/test/cpp/jit:jit - Conformance/LiteInterpreterDynamicTypeTestFixture.Conformance/8 (88.815) ✓ Pass: caffe2/test/cpp/jit:jit - Conformance/LiteInterpreterDynamicTypeTestFixture.Conformance/9 (90.332) Summary Pass: 10 ListingSuccess: 1 If you need help understanding your runs, please follow the wiki: https://fburl.com/posting_in_tpx_users Finished test run: https://www.internalfb.com/intern/testinfra/testrun/4222124735827748 ``` Reviewed By: qihqi Differential Revision: D33570442 fbshipit-source-id: 5c49e03b0f88068d444c84b4adeaaf45433ce1fa
This commit is contained in:
committed by
Facebook GitHub Bot
parent
81f693d509
commit
5f2b4be3b9
@ -2055,6 +2055,67 @@ void enumerateTupleType(
|
||||
}
|
||||
}
|
||||
|
||||
class LiteInterpreterDynamicTypeTestFixture
|
||||
: public ::testing::TestWithParam<size_t> {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
cu = std::make_shared<CompilationUnit>();
|
||||
std::vector<TypePtr> keyTypes = {
|
||||
AnyType::get(),
|
||||
IntType::get(),
|
||||
BoolType::get(),
|
||||
FloatType::get(),
|
||||
ComplexType::get(),
|
||||
StringType::get(),
|
||||
TensorType::get(),
|
||||
DeviceObjType::get(),
|
||||
};
|
||||
types = {
|
||||
NoneType::get(),
|
||||
NumberType::get(),
|
||||
ClassType::create("__torch__.TestClass1", cu),
|
||||
ClassType::create("__torch__.TestClass2", cu),
|
||||
AnyListType::get(),
|
||||
AnyTupleType::get(),
|
||||
StreamObjType::get(),
|
||||
CapsuleType::get(),
|
||||
GeneratorType::get(),
|
||||
StorageType::get(),
|
||||
VarType::create("t"),
|
||||
VarType::create("v"),
|
||||
AnyClassType::get()};
|
||||
std::copy(keyTypes.begin(), keyTypes.end(), back_inserter(types));
|
||||
auto expandTypes = [&](size_t tupleSize) {
|
||||
std::vector<TypePtr> nested;
|
||||
for (const auto& type : types) {
|
||||
if (!(type == AnyType::get())) {
|
||||
nested.emplace_back(ListType::create(type));
|
||||
if (!(type == NoneType::get() ||
|
||||
type->kind() == OptionalType::Kind)) {
|
||||
nested.emplace_back(OptionalType::create(type));
|
||||
}
|
||||
}
|
||||
for (const auto& keyType : keyTypes) {
|
||||
nested.emplace_back(DictType::create(keyType, type));
|
||||
}
|
||||
}
|
||||
std::vector<TypePtr> tmp;
|
||||
enumerateTupleType(tupleSize, tmp, types, nested);
|
||||
std::move(
|
||||
std::begin(nested), std::end(nested), std::back_inserter(types));
|
||||
};
|
||||
expandTypes(1);
|
||||
expandTypes(1);
|
||||
}
|
||||
std::shared_ptr<CompilationUnit> cu;
|
||||
std::vector<TypePtr> types;
|
||||
|
||||
public:
|
||||
static constexpr size_t kNumSplits = 10;
|
||||
};
|
||||
|
||||
constexpr size_t LiteInterpreterDynamicTypeTestFixture::kNumSplits;
|
||||
|
||||
/**
|
||||
* Enumerate all possible JIT types appearing in mobile runtime, and test
|
||||
* whether subtyping relation is preserved after one of the JIT types is
|
||||
@ -2065,55 +2126,14 @@ void enumerateTupleType(
|
||||
* of types. We call expandTypes() twice to test types nested less or equal
|
||||
* to two levels. e.g. List[Optional[Tensor]], Optional[Dict[Int, Bool]], etc.
|
||||
*/
|
||||
TEST(LiteInterpreterTest, DynamicType) {
|
||||
auto cu = std::make_shared<CompilationUnit>();
|
||||
std::vector<TypePtr> keyTypes = {
|
||||
AnyType::get(),
|
||||
IntType::get(),
|
||||
BoolType::get(),
|
||||
FloatType::get(),
|
||||
ComplexType::get(),
|
||||
StringType::get(),
|
||||
TensorType::get(),
|
||||
DeviceObjType::get(),
|
||||
};
|
||||
std::vector<TypePtr> types = {
|
||||
NoneType::get(),
|
||||
NumberType::get(),
|
||||
ClassType::create("__torch__.TestClass1", cu),
|
||||
ClassType::create("__torch__.TestClass2", cu),
|
||||
AnyListType::get(),
|
||||
AnyTupleType::get(),
|
||||
StreamObjType::get(),
|
||||
CapsuleType::get(),
|
||||
GeneratorType::get(),
|
||||
StorageType::get(),
|
||||
VarType::create("t"),
|
||||
VarType::create("v"),
|
||||
AnyClassType::get()};
|
||||
std::copy(keyTypes.begin(), keyTypes.end(), back_inserter(types));
|
||||
auto expandTypes = [&](size_t tupleSize) {
|
||||
std::vector<TypePtr> nested;
|
||||
for (const auto& type : types) {
|
||||
if (!(type == AnyType::get())) {
|
||||
nested.push_back(ListType::create(type));
|
||||
if (!(type == NoneType::get() || type->kind() == OptionalType::Kind)) {
|
||||
nested.push_back(OptionalType::create(type));
|
||||
}
|
||||
}
|
||||
for (const auto& keyType : keyTypes) {
|
||||
nested.push_back(DictType::create(keyType, type));
|
||||
}
|
||||
}
|
||||
std::vector<TypePtr> tmp;
|
||||
enumerateTupleType(tupleSize, tmp, types, nested);
|
||||
std::move(std::begin(nested), std::end(nested), std::back_inserter(types));
|
||||
};
|
||||
expandTypes(1);
|
||||
expandTypes(1);
|
||||
TEST_P(LiteInterpreterDynamicTypeTestFixture, Conformance) {
|
||||
size_t num = types.size() / LiteInterpreterDynamicTypeTestFixture::kNumSplits;
|
||||
size_t begin = num * GetParam();
|
||||
size_t end = std::min(types.size(), begin + num);
|
||||
for (const auto& a : types) {
|
||||
auto da = DynamicType::create(*a);
|
||||
for (const auto& b : types) {
|
||||
for (size_t i = begin; i < end; i++) {
|
||||
const auto& b = types[i];
|
||||
bool result = a->isSubtypeOf(*b);
|
||||
EXPECT_EQ(result, da->isSubtypeOf(*b));
|
||||
result = b->isSubtypeOf(*a);
|
||||
@ -2122,5 +2142,12 @@ TEST(LiteInterpreterTest, DynamicType) {
|
||||
}
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(
|
||||
PyTorch,
|
||||
LiteInterpreterDynamicTypeTestFixture,
|
||||
::testing::Range(
|
||||
static_cast<size_t>(0),
|
||||
LiteInterpreterDynamicTypeTestFixture::kNumSplits));
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
Reference in New Issue
Block a user