mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Small improvements to compare_models_torch binary (#65171)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/65171 Add the model comparison binary to BUCK, and also add some quality of life features such as controlling the input range. Test Plan: ``` # Build the binary cd ~/fbsource buck build -c ndk.custom_libcxx=false -c pt.enable_qpl=0 //xplat/caffe2:ptmobile_compareAndroid\#android-arm64 --show-ou # Push it to the device adb push buck-out/gen/xplat/caffe2/ptmobile_compareAndroid\#android-arm64 /data/local/tmp/compare_models # Run the benchmark binary BENCH_CMD="/data/local/tmp/compare_models" BENCH_CMD+=" --model=$PATH_TO_MODEL" BENCH_CMD+=" --refmodel=$PATH_TO_REFERENCE_MODEL" BENCH_CMD+=" --input_type=float --input_dims=$MODEL_INPUT_SIZE" BENCH_CMD+=" --iter=100" BENCH_CMD+=" --tolerance 1e-5" ``` Reviewed By: beback4u Differential Revision: D30371322 fbshipit-source-id: 5e520aaf119c90985a1d5a135f76e4057148333b
This commit is contained in:
committed by
Facebook GitHub Bot
parent
9601deb1b3
commit
f101070587
@ -14,6 +14,7 @@
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <iomanip>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
@ -47,6 +48,8 @@ C10_DEFINE_string(
|
||||
input_memory_format,
|
||||
"contiguous_format",
|
||||
"Input memory format (contiguous_format/channels_last)");
|
||||
C10_DEFINE_int(input_max, 1, "The maximum value inputs should have");
|
||||
C10_DEFINE_int(input_min, -1, "The minimum value inputs should have");
|
||||
C10_DEFINE_bool(
|
||||
no_inputs,
|
||||
false,
|
||||
@ -60,6 +63,7 @@ C10_DEFINE_bool(
|
||||
false,
|
||||
"Whether to print output with all one input tensor.");
|
||||
C10_DEFINE_int(iter, 10, "The number of iterations to run.");
|
||||
C10_DEFINE_int(report_freq, 1000, "An update will be reported every n iterations");
|
||||
C10_DEFINE_int(pytext_len, 0, "Length of input sequence.");
|
||||
C10_DEFINE_string(
|
||||
backend,
|
||||
@ -70,23 +74,42 @@ C10_DEFINE_string(
|
||||
"cpu",
|
||||
"what backend to use for model (vulkan, cpu, metal) (default=cpu)");
|
||||
C10_DEFINE_string(tolerance, "1e-5", "tolerance to use for comparison");
|
||||
C10_DEFINE_bool(
|
||||
report_failures,
|
||||
true,
|
||||
"Whether to report error during failed iterations");
|
||||
|
||||
bool checkRtol(
|
||||
const at::Tensor& diff,
|
||||
const std::vector<at::Tensor>& inputs,
|
||||
float tolerance) {
|
||||
float tolerance,
|
||||
bool report) {
|
||||
float maxValue = 0.0f;
|
||||
|
||||
for (const auto& tensor : inputs) {
|
||||
maxValue = fmax(tensor.abs().max().item<float>(), maxValue);
|
||||
}
|
||||
float threshold = tolerance * maxValue;
|
||||
float maxDiff = diff.abs().max().item<float>();
|
||||
|
||||
return maxDiff < (tolerance * maxValue);
|
||||
bool passed = maxDiff < threshold;
|
||||
if (!passed && report) {
|
||||
std::cout << "Check FAILED! Max diff allowed: "
|
||||
<< std::setw(10) << std::setprecision(5) << threshold
|
||||
<< " max diff: "
|
||||
<< std::setw(10) << std::setprecision(5) << maxDiff
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
bool almostEqual(const at::Tensor& a, const at::Tensor& b, float tolerance) {
|
||||
return checkRtol(a - b, {a, b}, tolerance);
|
||||
void report_pass_rate(int passed, int total) {
|
||||
int pass_rate = static_cast<int>(static_cast<float>(passed) / static_cast<float>(total) * 100);
|
||||
std::cout << "Output was equal within tolerance " << passed << "/"
|
||||
<< total
|
||||
<< " times. Pass rate: " << pass_rate
|
||||
<< std::setprecision(2) << "%" << std::endl;
|
||||
}
|
||||
|
||||
std::vector<std::string> split(
|
||||
@ -108,7 +131,9 @@ std::vector<c10::IValue> create_inputs(
|
||||
std::vector<c10::IValue>& refinputs,
|
||||
std::vector<c10::IValue>& inputs,
|
||||
std::string& refbackend,
|
||||
std::string& backend) {
|
||||
std::string& backend,
|
||||
const int range_min,
|
||||
const int range_max) {
|
||||
if (FLAGS_no_inputs) {
|
||||
return {};
|
||||
}
|
||||
@ -174,7 +199,7 @@ std::vector<c10::IValue> create_inputs(
|
||||
|
||||
const auto input_tensor = torch::rand(
|
||||
input_dims,
|
||||
at::TensorOptions(input_type).memory_format(input_memory_format));
|
||||
at::TensorOptions(input_type).memory_format(input_memory_format))*(range_max - range_min) - range_min;
|
||||
|
||||
if (refbackend == "vulkan") {
|
||||
refinputs.emplace_back(input_tensor.vulkan());
|
||||
@ -220,9 +245,17 @@ int main(int argc, char** argv) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (FLAGS_input_min >= FLAGS_input_max) {
|
||||
std::cerr << "Input min: " << FLAGS_input_min
|
||||
<< " should be less than input max: "
|
||||
<< FLAGS_input_max << std::endl;
|
||||
return 1;
|
||||
}
|
||||
|
||||
std::stringstream ss(FLAGS_tolerance);
|
||||
float tolerance = 0;
|
||||
ss >> tolerance;
|
||||
std::cout << "tolerance: " << tolerance << std::endl;
|
||||
|
||||
c10::InferenceMode mode;
|
||||
torch::autograd::AutoGradMode guard(false);
|
||||
@ -244,20 +277,24 @@ int main(int argc, char** argv) {
|
||||
for (int i = 0; i < FLAGS_iter; ++i) {
|
||||
std::vector<c10::IValue> refinputs;
|
||||
std::vector<c10::IValue> inputs;
|
||||
create_inputs(refinputs, inputs, FLAGS_refbackend, FLAGS_backend);
|
||||
create_inputs(refinputs, inputs, FLAGS_refbackend, FLAGS_backend, FLAGS_input_min, FLAGS_input_max);
|
||||
|
||||
const auto refoutput = refmodule.forward(refinputs).toTensor().cpu();
|
||||
const auto output = module.forward(inputs).toTensor().cpu();
|
||||
|
||||
bool check = almostEqual(refoutput, output, tolerance);
|
||||
bool check = checkRtol(refoutput-output, {refoutput, output}, tolerance, FLAGS_report_failures);
|
||||
if (check) {
|
||||
passed += 1;
|
||||
if (FLAGS_report_failures && !check) {
|
||||
std::cout << " (Iteration " << i << " failed)" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
if (i > 0 && (i+1) % FLAGS_report_freq == 0) {
|
||||
report_pass_rate(passed, i+1);
|
||||
}
|
||||
}
|
||||
std::cout << "Output was equal within tolerance " << passed << "/"
|
||||
<< FLAGS_iter
|
||||
<< " times. Pass rate: " << (float)passed / (float)FLAGS_iter * 100
|
||||
<< std::setprecision(2) << "%" << std::endl;
|
||||
report_pass_rate(passed, FLAGS_iter);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
Reference in New Issue
Block a user