diff --git a/binaries/compare_models_torch.cc b/binaries/compare_models_torch.cc
index bf88c390799f..7afac42589b6 100644
--- a/binaries/compare_models_torch.cc
+++ b/binaries/compare_models_torch.cc
@@ -74,6 +74,7 @@ C10_DEFINE_string(
     "cpu",
     "what backend to use for model (vulkan, cpu, metal) (default=cpu)");
 C10_DEFINE_string(tolerance, "1e-5", "tolerance to use for comparison");
+C10_DEFINE_int(nthreads, 1, "Number of threads to launch. Useful for checking correct concurrent behaviour.");
 C10_DEFINE_bool(
     report_failures,
     true,
@@ -232,6 +233,48 @@ std::vector<c10::IValue> create_inputs(
   return inputs;
 }
 
+void run_check(float tolerance) {
+  torch::jit::Module module = torch::jit::load(FLAGS_model);
+  torch::jit::Module refmodule = torch::jit::load(FLAGS_refmodel);
+
+  module.eval();
+  refmodule.eval();
+
+  std::thread::id this_id = std::this_thread::get_id();
+  std::cout << "Running check on thread " << this_id << "." << std::endl;
+
+  int passed = 0;
+  for (int i = 0; i < FLAGS_iter; ++i) {
+    std::vector<c10::IValue> refinputs;
+    std::vector<c10::IValue> inputs;
+    create_inputs(
+        refinputs, inputs,
+        FLAGS_refbackend, FLAGS_backend,
+        FLAGS_input_min, FLAGS_input_max);
+
+    const auto refoutput = refmodule.forward(refinputs).toTensor().cpu();
+    const auto output = module.forward(inputs).toTensor().cpu();
+
+    bool check = checkRtol(
+        refoutput-output,
+        {refoutput, output},
+        tolerance,
+        FLAGS_report_failures);
+
+    if (check) {
+      passed += 1;
+    }
+    else if (FLAGS_report_failures) {
+      std::cout << "  (Iteration " << i << " failed)" << std::endl;
+    }
+
+    if (i > 0 && (i+1) % FLAGS_report_freq == 0) {
+      report_pass_rate(passed, i+1);
+    }
+  }
+  report_pass_rate(passed, FLAGS_iter);
+}
+
 int main(int argc, char** argv) {
   c10::SetUsageMessage(
       "Run accuracy comparison to a reference model for a pytorch model.\n"
@@ -260,41 +303,24 @@ int main(int argc, char** argv) {
   c10::InferenceMode mode;
   torch::autograd::AutoGradMode guard(false);
   torch::jit::GraphOptimizerEnabledGuard no_optimizer_guard(false);
-  auto module = torch::jit::load(FLAGS_model);
-  auto refmodule = torch::jit::load(FLAGS_refmodel);
-
-  module.eval();
-  refmodule.eval();
 
   c10::CPUCachingAllocator caching_allocator;
   c10::optional<c10::WithCPUCachingAllocatorGuard> caching_allocator_guard;
   if (FLAGS_use_caching_allocator) {
     caching_allocator_guard.emplace(&caching_allocator);
   }
-  std::cout << "Running modules." << std::endl;
 
-  int passed = 0;
-  for (int i = 0; i < FLAGS_iter; ++i) {
-    std::vector<c10::IValue> refinputs;
-    std::vector<c10::IValue> inputs;
-    create_inputs(refinputs, inputs, FLAGS_refbackend, FLAGS_backend, FLAGS_input_min, FLAGS_input_max);
+  std::vector<std::thread> check_threads;
+  check_threads.reserve(FLAGS_nthreads);
+  for (int i = 0; i < FLAGS_nthreads; ++i) {
+    check_threads.emplace_back(std::thread(run_check, tolerance));
+  }
 
-    const auto refoutput = refmodule.forward(refinputs).toTensor().cpu();
-    const auto output = module.forward(inputs).toTensor().cpu();
-
-    bool check = checkRtol(refoutput-output, {refoutput, output}, tolerance, FLAGS_report_failures);
-    if (check) {
-      passed += 1;
-      if (FLAGS_report_failures && !check) {
-        std::cout << "  (Iteration " << i << " failed)" << std::endl;
-      }
-    }
-
-    if (i > 0 && (i+1) % FLAGS_report_freq == 0) {
-      report_pass_rate(passed, i+1);
+  for (std::thread& th : check_threads) {
+    if (th.joinable()) {
+      th.join();
     }
   }
-  report_pass_rate(passed, FLAGS_iter);
 
   return 0;
 }
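
Usage note: with this change the accuracy check can be run from several threads at once via the new nthreads flag. A hypothetical invocation is sketched below; the binary path, model file names, and iteration count are placeholders, while the flag names follow the C10_DEFINE_* declarations in this file:

    ./compare_models_torch \
        --model=model_vulkan.pt \
        --refmodel=model_cpu.pt \
        --backend=vulkan \
        --refbackend=cpu \
        --iter=100 \
        --nthreads=4

Each thread calls run_check and loads its own copies of the model and reference model, so no module state is shared across threads; only the pass-rate reporting on std::cout is interleaved.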