mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Light] Remove ambiguity from compile_spec names, use actual output type (#67209)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67209
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67198

Fixes a couple of instances where parameters were named method_compile_spec when they were actually compile_specs, i.e. dicts that can hold multiple method_compile_specs inside. Also uses the actual output dtype from the buffer instead of hard-coding Float.

Test Plan: Mobilenetv3 compiles and runs fine:
```
(pytorch) ~/fbsource/fbcode/caffe2/fb/nnc
└─ $ PYTORCH_JIT_LOG_LEVEL="aot_compiler" buck run //caffe2/binaries:aot_model_compiler -- --model mobilenetv3.pt --model_name=pytorch_dev_mobilenetv3 --model_version=v1 --input_dims="1,3,224,224"
Downloaded 4501/6195 artifacts, 433.89 Mbytes, 14.3% cache miss (for updated rules)
Building: finished in 06:34.6 min (100%) 20233/20233 jobs, 5467/20233 updated
Total time: 06:35.0 min
BUILD SUCCEEDED
The compiled llvm assembly code was saved to mobilenetv3.compiled.ll
The compiled model was saved to mobilenetv3.compiled.pt

└─ $ ./compile_model.sh -m pytorch_dev_mobilenetv3 -p /data/users/priyaramani/fbsource/fbcode/caffe2/fb/nnc/mobilenetv3.pt -v v1 -i "1,3,224,224"
+ VERSION=v1
+ getopts m:p:v:i:h opt
+ case $opt in
+ MODEL=pytorch_dev_mobilenetv3
.
.
Columns 961 to 970
1e-11 *
-4.2304 -3.9674  2.4473 -0.8664 -0.7513  1.2140  0.0010  3.8675  1.2714  2.2989

Columns 971 to 980
1e-11 *
-2.7203  1.6772 -0.7460 -0.6936  4.4421 -0.9865 -0.5186 -1.4441  1.3047 -1.6112

Columns 981 to 990
1e-11 *
 0.1275 -1.8815  2.5105 -0.4871 -2.2342  0.8520  0.8658  1.6180  3.8901 -0.2454

Columns 991 to 1000
1e-11 *
-1.4896  4.1337 -2.6640  0.8226  0.2441 -1.4830 -1.7430  1.8758  0.5481  0.5093
[ CPUFloatType{1,1000} ]
Starting benchmark.
Running warmup runs.
Main runs.
Main run finished. Milliseconds per iter: 276.255. Iters per second: 3.61984
Memory usage before main runs: 104366080 bytes
Memory usage after main runs: 343441408 bytes
Average memory increase per iter: 2.39075e+07 bytes
0 value means "not available" in above
```

Reviewed By: ljk53

Differential Revision: D31698338

fbshipit-source-id: da6c74c1321ec02e0652f3afe6f97bf789d3361b
committed by Facebook GitHub Bot
parent ad5731cacc
commit ecf7e96969
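For context on the rename, here is the nesting the summary describes: a compile_spec maps each method name to that method's own spec dict, which is exactly what getInputSizes in the diff below indexes with FLAGS_method_name. A minimal sketch; the "forward" key and the 1x3x224x224 shape are illustrative, not taken from the PR:

```cpp
#include <torch/script.h>

// Sketch: build a compile_spec that nests one method_compile_spec.
c10::Dict<c10::IValue, c10::IValue> makeCompileSpec() {
  // Per-method spec: {"sizes": [[1, 3, 224, 224]]}, one shape per input.
  c10::List<c10::List<int64_t>> shapes;
  shapes.push_back(c10::List<int64_t>({1, 3, 224, 224}));

  c10::Dict<c10::IValue, c10::IValue> method_spec(
      c10::StringType::get(), c10::AnyType::get());
  method_spec.insert("sizes", shapes);

  // Outer spec: one entry per method. This is what the code calls a
  // compile_spec; a method_compile_spec is only the inner dict.
  c10::Dict<c10::IValue, c10::IValue> compile_spec(
      c10::StringType::get(), c10::AnyType::get());
  compile_spec.insert("forward", method_spec);
  return compile_spec;
}
```

Calling the outer dict method_compile_spec suggested it held a single method's options; since it holds one inner spec per method, the rename removes that ambiguity.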
```diff
@@ -73,8 +73,8 @@ c10::Dict<c10::IValue, c10::IValue> createCompileSpec() {
 }
 
 std::vector<std::vector<int64_t>> getInputSizes (
-    const c10::Dict<c10::IValue, c10::IValue>& method_compile_spec) {
-  auto input_shapes = method_compile_spec.at(FLAGS_method_name).toGenericDict().at("sizes").toList();
+    const c10::Dict<c10::IValue, c10::IValue>& compile_spec) {
+  auto input_shapes = compile_spec.at(FLAGS_method_name).toGenericDict().at("sizes").toList();
   std::vector<std::vector<int64_t>> inputSizes;
   for (const auto& input_shape : input_shapes) {
     auto sizes = ((c10::IValue) input_shape).toIntVector();
```
```diff
@@ -105,7 +105,7 @@ void writeOutputLlvmAssembly(const std::string& asm_code) {
 
 c10::IValue preprocess(
     const torch::jit::Module& mod,
-    const c10::Dict<c10::IValue, c10::IValue>& method_compile_spec,
+    const c10::Dict<c10::IValue, c10::IValue>& compile_spec,
     const torch::jit::BackendDebugHandleGenerator& generate_debug_handles) {
 
   std::string output_llvm_file_name = FLAGS_output_llvm;
@@ -116,7 +116,7 @@ c10::IValue preprocess(
 
   auto method = mod.get_method(FLAGS_method_name);
   auto graph = method.function().graph()->copy();
-  auto sizes = getInputSizes(method_compile_spec);
+  auto sizes = getInputSizes(compile_spec);
 
   std::string llvm_asm_code;
   auto compiled = torch::jit::mobile::nnc::aotCompile(FLAGS_method_name, graph, sizes);
```
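The preprocess signature above is the shape of PyTorch's to-backend preprocessing hook. For orientation, this is roughly how such a function gets attached to a backend name; a sketch assuming the standard backend_preprocess_register hook, with the "nnc" name taken from this code's context:

```cpp
#include <torch/csrc/jit/backends/backend_preprocess.h>

// The hook receives the whole compile_spec (covering all methods),
// which is why the parameter rename above matters.
c10::IValue preprocess(
    const torch::jit::Module& mod,
    const c10::Dict<c10::IValue, c10::IValue>& compile_spec,
    const torch::jit::BackendDebugHandleGenerator& generate_debug_handles);

// Registration ties the function to the backend name at static-init time.
static const auto preprocess_reg =
    torch::jit::backend_preprocess_register("nnc", preprocess);
```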
```diff
@@ -39,6 +39,7 @@ std::vector<mobile::nnc::InputSpec> toInputSpecs(
   for (const auto& sizes : inputSizes) {
     mobile::nnc::InputSpec spec;
     spec.sizes_ = sizes;
+    // TODO: get and set input dtype
     spec.dtype_ = c10::ScalarType::Float;
     specs.emplace_back(std::move(spec));
   }
```
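The hunk above only documents the hard-coded input dtype with a TODO. For illustration (not part of this PR), one plausible way to fill it would be to read the scalar type off the graph input's TensorType, keeping Float as the fallback the current code uses:

```cpp
#include <torch/csrc/jit/ir/ir.h>

// Hypothetical helper: recover an input dtype from a JIT graph input,
// defaulting to Float when the type is unannotated.
c10::ScalarType inputDtypeOrFloat(const torch::jit::Value* input) {
  if (auto tensor_type = input->type()->cast<c10::TensorType>()) {
    if (auto scalar_type = tensor_type->scalarType()) {
      return *scalar_type;
    }
  }
  return c10::ScalarType::Float;
}
```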
```diff
@@ -75,8 +76,7 @@ std::unique_ptr<Function> compileMethod(
     OutputSpec output;
     output.sizes_ = getConstSizes(ba.buf());
     // TODO: assert the output is a buffer and not a scalar
-    // TODO: use actual dtype
-    output.dtype_ = c10::ScalarType::Float;
+    output.dtype_ = ba.buf()->dtype().scalar_type();
     out_spec.push_back(output);
   }
   func->set_output_specs(out_spec);
```
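The output-side fix in the last hunk resolves its TODO directly: the buffer produced by NNC codegen already carries a dtype, so the OutputSpec can take it from there instead of assuming Float. A standalone restatement of that one-liner as a sketch, where ba is a tensorexpr CodeGen::BufferArg as in the surrounding code:

```cpp
#include <torch/csrc/jit/tensorexpr/codegen.h>

// Sketch: read the output dtype off the codegen buffer (e.g. Half or
// Byte for non-Float models) rather than hard-coding Float.
c10::ScalarType outputDtype(
    const torch::jit::tensorexpr::CodeGen::BufferArg& ba) {
  return ba.buf()->dtype().scalar_type();
}
```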