[Light] Remove ambiguity from compile_spec names, use actual output type (#67209)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67209

Pull Request resolved: https://github.com/pytorch/pytorch/pull/67198

Fixes a couple of instances where parameters were named `method_compile_spec` when they were actually compile specs, i.e. dicts that can hold one `method_compile_spec` per method.
Also uses the actual output dtype from the buffer instead of hardcoding `Float`.
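
For illustration, a minimal sketch of the distinction; the helper name and the `"forward"`/`"sizes"` keys here are assumptions mirroring the access pattern in `getInputSizes` in the diff below, not code from this PR:

```cpp
#include <torch/script.h>

// Hypothetical helper: a compile spec maps method names to per-method
// specs, so one dict can carry many method compile specs at once.
c10::Dict<c10::IValue, c10::IValue> makeCompileSpec() {
  // Per-method spec: "sizes" holds one shape per model input.
  c10::Dict<c10::IValue, c10::IValue> method_spec(
      c10::StringType::get(), c10::AnyType::get());
  c10::List<c10::List<int64_t>> sizes;
  sizes.push_back(c10::List<int64_t>({1, 3, 224, 224}));
  method_spec.insert("sizes", sizes);

  // The outer compile spec is keyed by method name ("forward" here), so a
  // parameter of this type should not be called method_compile_spec.
  c10::Dict<c10::IValue, c10::IValue> compile_spec(
      c10::StringType::get(), c10::AnyType::get());
  compile_spec.insert("forward", method_spec);
  return compile_spec;
}
```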

Test Plan:
MobileNetV3 compiles and runs fine:
```
(pytorch)  ~/fbsource/fbcode/caffe2/fb/nnc
└─ $ PYTORCH_JIT_LOG_LEVEL="aot_compiler" buck run //caffe2/binaries:aot_model_compiler -- --model mobilenetv3.pt --model_name=pytorch_dev_mobilenetv3 --model_version=v1 --input_dims="1,3,224,224"
Downloaded 4501/6195 artifacts, 433.89 Mbytes, 14.3% cache miss (for updated rules)
Building: finished in 06:34.6 min (100%) 20233/20233 jobs, 5467/20233 updated
  Total time: 06:35.0 min
BUILD SUCCEEDED
The compiled llvm assembly code was saved to mobilenetv3.compiled.ll
The compiled model was saved to mobilenetv3.compiled.pt

└─ $ ./compile_model.sh -m pytorch_dev_mobilenetv3 -p /data/users/priyaramani/fbsource/fbcode/caffe2/fb/nnc/mobilenetv3.pt -v v1 -i "1,3,224,224"
+ VERSION=v1
+ getopts m:p:v:i:h opt
+ case $opt in
+ MODEL=pytorch_dev_mobilenetv3
.
.
Columns 961 to 970
1e-11 *
-4.2304 -3.9674  2.4473 -0.8664 -0.7513  1.2140  0.0010  3.8675  1.2714  2.2989

Columns 971 to 980
1e-11 *
-2.7203  1.6772 -0.7460 -0.6936  4.4421 -0.9865 -0.5186 -1.4441  1.3047 -1.6112

Columns 981 to 990
1e-11 *
 0.1275 -1.8815  2.5105 -0.4871 -2.2342  0.8520  0.8658  1.6180  3.8901 -0.2454

Columns 991 to 1000
1e-11 *
-1.4896  4.1337 -2.6640  0.8226  0.2441 -1.4830 -1.7430  1.8758  0.5481  0.5093
[ CPUFloatType{1,1000} ]
Starting benchmark.
Running warmup runs.
Main runs.
Main run finished. Milliseconds per iter: 276.255. Iters per second: 3.61984
Memory usage before main runs: 104366080 bytes
Memory usage after main runs: 343441408 bytes
Average memory increase per iter: 2.39075e+07 bytes
0 value means "not available" in above
```

Reviewed By: ljk53

Differential Revision: D31698338

fbshipit-source-id: da6c74c1321ec02e0652f3afe6f97bf789d3361b
Priya Ramani
2021-10-25 17:42:38 -07:00
committed by Facebook GitHub Bot
parent ad5731cacc
commit ecf7e96969
2 changed files with 6 additions and 6 deletions


```diff
@@ -73,8 +73,8 @@ c10::Dict<c10::IValue, c10::IValue> createCompileSpec() {
 }
 
 std::vector<std::vector<int64_t>> getInputSizes (
-    const c10::Dict<c10::IValue, c10::IValue>& method_compile_spec) {
-  auto input_shapes = method_compile_spec.at(FLAGS_method_name).toGenericDict().at("sizes").toList();
+    const c10::Dict<c10::IValue, c10::IValue>& compile_spec) {
+  auto input_shapes = compile_spec.at(FLAGS_method_name).toGenericDict().at("sizes").toList();
   std::vector<std::vector<int64_t>> inputSizes;
   for (const auto& input_shape : input_shapes) {
     auto sizes = ((c10::IValue) input_shape).toIntVector();
@@ -105,7 +105,7 @@ void writeOutputLlvmAssembly(const std::string& asm_code) {
 
 c10::IValue preprocess(
     const torch::jit::Module& mod,
-    const c10::Dict<c10::IValue, c10::IValue>& method_compile_spec,
+    const c10::Dict<c10::IValue, c10::IValue>& compile_spec,
     const torch::jit::BackendDebugHandleGenerator& generate_debug_handles) {
   std::string output_llvm_file_name = FLAGS_output_llvm;
@@ -116,7 +116,7 @@ c10::IValue preprocess(
   auto method = mod.get_method(FLAGS_method_name);
   auto graph = method.function().graph()->copy();
-  auto sizes = getInputSizes(method_compile_spec);
+  auto sizes = getInputSizes(compile_spec);
 
   std::string llvm_asm_code;
   auto compiled = torch::jit::mobile::nnc::aotCompile(FLAGS_method_name, graph, sizes);
```
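With the rename, a call site reads the way the type suggests; using the hypothetical `makeCompileSpec` from the sketch above:

```cpp
auto compile_spec = makeCompileSpec();     // dict of method name -> method spec
auto sizes = getInputSizes(compile_spec);  // {{1, 3, 224, 224}}
```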

```diff
@@ -39,6 +39,7 @@ std::vector<mobile::nnc::InputSpec> toInputSpecs(
   for (const auto& sizes : inputSizes) {
     mobile::nnc::InputSpec spec;
     spec.sizes_ = sizes;
+    // TODO: get and set input dtype
     spec.dtype_ = c10::ScalarType::Float;
     specs.emplace_back(std::move(spec));
   }
@@ -75,8 +76,7 @@ std::unique_ptr<Function> compileMethod(
     OutputSpec output;
     output.sizes_ = getConstSizes(ba.buf());
     // TODO: assert the output is a buffer and not a scalar
-    // TODO: use actual dtype
-    output.dtype_ = c10::ScalarType::Float;
+    output.dtype_ = ba.buf()->dtype().scalar_type();
     out_spec.push_back(output);
   }
   func->set_output_specs(out_spec);
```
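
For reference, a minimal sketch (not part of this PR) of why reading the dtype off the buffer is the right source of truth; the buffer name and shape are illustrative, using the tensorexpr API the diff touches:

```cpp
#include <torch/csrc/jit/tensorexpr/expr.h>
#include <torch/csrc/jit/tensorexpr/types.h>

using namespace torch::jit::tensorexpr;

// A buffer declared with a non-Float dtype now flows into the OutputSpec
// instead of being overwritten by the hardcoded Float.
BufHandle out("out", {1, 1000}, kHalf);
c10::ScalarType st = out.dtype().scalar_type();  // c10::ScalarType::Half
```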