Files
pytorch/test/cpp/common/main.cpp
Bin Bao 4946638f06 [AOTI] Add ABI-compatibility tests (#123848)
Summary: In AOTInductor-generated CPU model code, there can be direct references to some aten/c10 utility functions and data structures, e.g. at::vec and c10::Half. These are performance-critical, so it doesn't make sense to create C shims for them. Instead, we make sure they are implemented in a header-only way, and use this set of tests to guard future changes.

There are more header files to be updated, but we will do it in other followup PRs.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123848
Approved by: https://github.com/jansel
ghstack dependencies: #123847
2024-04-19 00:51:24 +00:00

34 lines
905 B
C++

#include <gtest/gtest.h>
#include <torch/cuda.h>
#include <iostream>
#include <string>
std::string add_negative_flag(const std::string& flag) {
std::string filter = ::testing::GTEST_FLAG(filter);
if (filter.find('-') == std::string::npos) {
filter.push_back('-');
} else {
filter.push_back(':');
}
filter += flag;
return filter;
}
// Test-runner entry point. It initializes gtest from the command line, then
// narrows the test filter based on the CUDA devices visible to torch:
//   - no CUDA available  -> skip every `*_CUDA` and `*_MultiCUDA` test;
//   - fewer than 2 GPUs  -> skip only the `*_MultiCUDA` tests.
// Any filter the user passed via --gtest_filter is preserved; the exclusions
// are appended to it by add_negative_flag().
int main(int argc, char* argv[]) {
  ::testing::InitGoogleTest(&argc, argv);

  // Pattern list to exclude; stays null when every device-dependent test can run.
  const char* excluded_patterns = nullptr;
  if (!torch::cuda::is_available()) {
    std::cout << "CUDA not available. Disabling CUDA and MultiCUDA tests"
              << std::endl;
    excluded_patterns = "*_CUDA:*_MultiCUDA";
  } else if (torch::cuda::device_count() < 2) {
    // CUDA is available here, so fewer than two devices means exactly one.
    std::cout << "Only one CUDA device detected. Disabling MultiCUDA tests"
              << std::endl;
    excluded_patterns = "*_MultiCUDA";
  }

  if (excluded_patterns != nullptr) {
    ::testing::GTEST_FLAG(filter) = add_negative_flag(excluded_patterns);
  }
  return RUN_ALL_TESTS();
}