Add multithreading test to model compare binary (#80958)

This diff adds an option (`--nthreads`) that launches the specified number of threads, each of which loads the models and executes the correctness check on them.

Differential Revision: [D37465661](https://our.internmc.facebook.com/intern/diff/D37465661/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80958
Approved by: https://github.com/manuelcandales
Author: ssjia
Date: 2022-07-06 07:57:28 -07:00
Committed by: PyTorch MergeBot
Parent: ec594dd305
Commit: cb630c775e


@@ -74,6 +74,7 @@ C10_DEFINE_string(
     "cpu",
     "what backend to use for model (vulkan, cpu, metal) (default=cpu)");
 C10_DEFINE_string(tolerance, "1e-5", "tolerance to use for comparison");
+C10_DEFINE_int(nthreads, 1, "Number of threads to launch. Useful for checking correct concurrent behaviour.");
 C10_DEFINE_bool(
     report_failures,
     true,
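As an aside on the flag mechanism: `C10_DEFINE_int` comes from c10's lightweight flags facility and generates an `int` global named `FLAGS_nthreads`, which `c10::ParseCommandLineFlags` fills in from `--nthreads` (this binary presumably calls it alongside the `c10::SetUsageMessage` call shown below). A minimal standalone sketch of that pattern; the demo `main` and its output are illustrative, not part of this diff:

#include <c10/util/Flags.h>
#include <iostream>

// Mirrors the flag added above: defines `int FLAGS_nthreads` with a
// default of 1 and a help string surfaced by --help.
C10_DEFINE_int(nthreads, 1, "Number of threads to launch.");

int main(int argc, char** argv) {
  // Parses --nthreads (and any other defined flags) out of argv.
  if (!c10::ParseCommandLineFlags(&argc, &argv)) {
    std::cerr << "Failed to parse command line flags" << std::endl;
    return 1;
  }
  // e.g. invoking the binary with `--nthreads 4` prints 4 here.
  std::cout << "FLAGS_nthreads = " << FLAGS_nthreads << std::endl;
  return 0;
}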
@@ -232,6 +233,48 @@ std::vector<c10::IValue> create_inputs(
   return inputs;
 }
 
+void run_check(float tolerance) {
+  torch::jit::Module module = torch::jit::load(FLAGS_model);
+  torch::jit::Module refmodule = torch::jit::load(FLAGS_refmodel);
+  module.eval();
+  refmodule.eval();
+
+  std::thread::id this_id = std::this_thread::get_id();
+  std::cout << "Running check on thread " << this_id << "." << std::endl;
+
+  int passed = 0;
+  for (int i = 0; i < FLAGS_iter; ++i) {
+    std::vector<c10::IValue> refinputs;
+    std::vector<c10::IValue> inputs;
+    create_inputs(
+        refinputs, inputs,
+        FLAGS_refbackend, FLAGS_backend,
+        FLAGS_input_min, FLAGS_input_max);
+
+    const auto refoutput = refmodule.forward(refinputs).toTensor().cpu();
+    const auto output = module.forward(inputs).toTensor().cpu();
+
+    bool check = checkRtol(
+        refoutput-output,
+        {refoutput, output},
+        tolerance,
+        FLAGS_report_failures);
+
+    if (check) {
+      passed += 1;
+    }
+    else if (FLAGS_report_failures) {
+      std::cout << "  (Iteration " << i << " failed)" << std::endl;
+    }
+
+    if (i > 0 && (i+1) % FLAGS_report_freq == 0) {
+      report_pass_rate(passed, i+1);
+    }
+  }
+  report_pass_rate(passed, FLAGS_iter);
+}
+
 int main(int argc, char** argv) {
   c10::SetUsageMessage(
       "Run accuracy comparison to a reference model for a pytorch model.\n"
@@ -260,41 +303,24 @@ int main(int argc, char** argv) {
   c10::InferenceMode mode;
   torch::autograd::AutoGradMode guard(false);
   torch::jit::GraphOptimizerEnabledGuard no_optimizer_guard(false);
-  auto module = torch::jit::load(FLAGS_model);
-  auto refmodule = torch::jit::load(FLAGS_refmodel);
-  module.eval();
-  refmodule.eval();
 
   c10::CPUCachingAllocator caching_allocator;
   c10::optional<c10::WithCPUCachingAllocatorGuard> caching_allocator_guard;
   if (FLAGS_use_caching_allocator) {
     caching_allocator_guard.emplace(&caching_allocator);
   }
-  std::cout << "Running modules." << std::endl;
 
-  int passed = 0;
-  for (int i = 0; i < FLAGS_iter; ++i) {
-    std::vector<c10::IValue> refinputs;
-    std::vector<c10::IValue> inputs;
-    create_inputs(refinputs, inputs, FLAGS_refbackend, FLAGS_backend, FLAGS_input_min, FLAGS_input_max);
+  std::vector<std::thread> check_threads;
+  check_threads.reserve(FLAGS_nthreads);
+  for (int i = 0; i < FLAGS_nthreads; ++i) {
+    check_threads.emplace_back(std::thread(run_check, tolerance));
+  }
 
-    const auto refoutput = refmodule.forward(refinputs).toTensor().cpu();
-    const auto output = module.forward(inputs).toTensor().cpu();
-    bool check = checkRtol(refoutput-output, {refoutput, output}, tolerance, FLAGS_report_failures);
-    if (check) {
-      passed += 1;
-      if (FLAGS_report_failures && !check) {
-        std::cout << "  (Iteration " << i << " failed)" << std::endl;
-      }
-    }
-    if (i > 0 && (i+1) % FLAGS_report_freq == 0) {
-      report_pass_rate(passed, i+1);
-    }
-  }
-  report_pass_rate(passed, FLAGS_iter);
+  for (std::thread& th : check_threads) {
+    if (th.joinable()) {
+      th.join();
+    }
+  }
   return 0;
 }
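Stripped of the model-specific details, the new `main` reduces to a standard launch-and-join pattern over `FLAGS_nthreads` worker threads. A self-contained illustration of that pattern (a sketch only: `do_work` is a hypothetical stand-in for `run_check`, and constants replace the flags):

#include <iostream>
#include <thread>
#include <vector>

// Stand-in for run_check(tolerance): each thread runs the full
// correctness loop independently, on its own copies of the models.
void do_work(float tolerance) {
  std::cout << "thread " << std::this_thread::get_id()
            << " running with tolerance " << tolerance << std::endl;
}

int main() {
  const int nthreads = 4;       // would come from FLAGS_nthreads
  const float tolerance = 1e-5f;  // would come from FLAGS_tolerance

  std::vector<std::thread> threads;
  threads.reserve(nthreads);
  for (int i = 0; i < nthreads; ++i) {
    // std::thread forwards `tolerance` to do_work on the new thread.
    threads.emplace_back(do_work, tolerance);
  }

  // Join everything before exiting: a std::thread that is still
  // joinable when destroyed terminates the program.
  for (std::thread& th : threads) {
    if (th.joinable()) {
      th.join();
    }
  }
  return 0;
}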