mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Add multithreading test to model compare binary (#80958)
This diff adds a option (`--nthreads`) which will launch the specified number of threads to load the models execute the correctness check on them. Differential Revision: [D37465661](https://our.internmc.facebook.com/intern/diff/D37465661/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/80958 Approved by: https://github.com/manuelcandales
This commit is contained in:
@ -74,6 +74,7 @@ C10_DEFINE_string(
|
||||
"cpu",
|
||||
"what backend to use for model (vulkan, cpu, metal) (default=cpu)");
|
||||
C10_DEFINE_string(tolerance, "1e-5", "tolerance to use for comparison");
|
||||
C10_DEFINE_int(nthreads, 1, "Number of threads to launch. Useful for checking correct concurrent behaviour.");
|
||||
C10_DEFINE_bool(
|
||||
report_failures,
|
||||
true,
|
||||
@ -232,6 +233,48 @@ std::vector<c10::IValue> create_inputs(
|
||||
return inputs;
|
||||
}
|
||||
|
||||
void run_check(float tolerance) {
|
||||
torch::jit::Module module = torch::jit::load(FLAGS_model);
|
||||
torch::jit::Module refmodule = torch::jit::load(FLAGS_refmodel);
|
||||
|
||||
module.eval();
|
||||
refmodule.eval();
|
||||
|
||||
std::thread::id this_id = std::this_thread::get_id();
|
||||
std::cout << "Running check on thread " << this_id << "." << std::endl;
|
||||
|
||||
int passed = 0;
|
||||
for (int i = 0; i < FLAGS_iter; ++i) {
|
||||
std::vector<c10::IValue> refinputs;
|
||||
std::vector<c10::IValue> inputs;
|
||||
create_inputs(
|
||||
refinputs, inputs,
|
||||
FLAGS_refbackend, FLAGS_backend,
|
||||
FLAGS_input_min, FLAGS_input_max);
|
||||
|
||||
const auto refoutput = refmodule.forward(refinputs).toTensor().cpu();
|
||||
const auto output = module.forward(inputs).toTensor().cpu();
|
||||
|
||||
bool check = checkRtol(
|
||||
refoutput-output,
|
||||
{refoutput, output},
|
||||
tolerance,
|
||||
FLAGS_report_failures);
|
||||
|
||||
if (check) {
|
||||
passed += 1;
|
||||
}
|
||||
else if (FLAGS_report_failures) {
|
||||
std::cout << " (Iteration " << i << " failed)" << std::endl;
|
||||
}
|
||||
|
||||
if (i > 0 && (i+1) % FLAGS_report_freq == 0) {
|
||||
report_pass_rate(passed, i+1);
|
||||
}
|
||||
}
|
||||
report_pass_rate(passed, FLAGS_iter);
|
||||
}
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
c10::SetUsageMessage(
|
||||
"Run accuracy comparison to a reference model for a pytorch model.\n"
|
||||
@ -260,41 +303,24 @@ int main(int argc, char** argv) {
|
||||
c10::InferenceMode mode;
|
||||
torch::autograd::AutoGradMode guard(false);
|
||||
torch::jit::GraphOptimizerEnabledGuard no_optimizer_guard(false);
|
||||
auto module = torch::jit::load(FLAGS_model);
|
||||
auto refmodule = torch::jit::load(FLAGS_refmodel);
|
||||
|
||||
module.eval();
|
||||
refmodule.eval();
|
||||
|
||||
c10::CPUCachingAllocator caching_allocator;
|
||||
c10::optional<c10::WithCPUCachingAllocatorGuard> caching_allocator_guard;
|
||||
if (FLAGS_use_caching_allocator) {
|
||||
caching_allocator_guard.emplace(&caching_allocator);
|
||||
}
|
||||
std::cout << "Running modules." << std::endl;
|
||||
|
||||
int passed = 0;
|
||||
for (int i = 0; i < FLAGS_iter; ++i) {
|
||||
std::vector<c10::IValue> refinputs;
|
||||
std::vector<c10::IValue> inputs;
|
||||
create_inputs(refinputs, inputs, FLAGS_refbackend, FLAGS_backend, FLAGS_input_min, FLAGS_input_max);
|
||||
std::vector<std::thread> check_threads;
|
||||
check_threads.reserve(FLAGS_nthreads);
|
||||
for (int i = 0; i < FLAGS_nthreads; ++i) {
|
||||
check_threads.emplace_back(std::thread(run_check, tolerance));
|
||||
}
|
||||
|
||||
const auto refoutput = refmodule.forward(refinputs).toTensor().cpu();
|
||||
const auto output = module.forward(inputs).toTensor().cpu();
|
||||
|
||||
bool check = checkRtol(refoutput-output, {refoutput, output}, tolerance, FLAGS_report_failures);
|
||||
if (check) {
|
||||
passed += 1;
|
||||
if (FLAGS_report_failures && !check) {
|
||||
std::cout << " (Iteration " << i << " failed)" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
if (i > 0 && (i+1) % FLAGS_report_freq == 0) {
|
||||
report_pass_rate(passed, i+1);
|
||||
for (std::thread& th : check_threads) {
|
||||
if (th.joinable()) {
|
||||
th.join();
|
||||
}
|
||||
}
|
||||
report_pass_rate(passed, FLAGS_iter);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
Reference in New Issue
Block a user